Commit

benchmark.subset()
gasse committed Oct 16, 2024
1 parent 36ba273 commit 620b5f4
Showing 2 changed files with 99 additions and 27 deletions.
102 changes: 75 additions & 27 deletions browsergym/experiments/src/browsergym/experiments/benchmark.py
@@ -2,12 +2,12 @@
 import logging
 import pathlib
 import pkgutil
-from dataclasses import dataclass
-from typing import Literal
+from dataclasses import dataclass, field
+from typing import Literal, Optional
 
 import numpy as np
 import pandas as pd
-from dataclasses_json import DataClassJsonMixin
+from dataclasses_json import DataClassJsonMixin, config
 
 from browsergym.core.action.highlevel import HighLevelActionSet
 from browsergym.experiments.loop import SEED_MAX, EnvArgs
@@ -40,28 +40,63 @@ class Benchmark(DataClassJsonMixin):
     name: str
     high_level_action_set_args: HighLevelActionSetArgs
     env_args_list: list[EnvArgs]
+    task_metadata: Optional[pd.DataFrame] = field(
+        default_factory=lambda: None,
+        metadata=config(
+            encoder=lambda df: df.to_dict(orient="records") if df is not None else None,
+            decoder=lambda items: pd.DataFrame(items) if items is not None else None,
+        ),
+    )
+
+    def __post_init__(self):
+        # if no metadata is present, generate a dataframe with a single "task_name" column
+        if self.task_metadata is None:
+            unique_task_names = list(set([env_args.task_name for env_args in self.env_args_list]))
+            self.task_metadata = pd.DataFrame(
+                [{"task_name": task_name} for task_name in unique_task_names]
+            )
+        # make sure all tasks in env_args are in the metadata
+        metadata_tasks = list(self.task_metadata["task_name"])
+        assert all([env_args.task_name in metadata_tasks for env_args in self.env_args_list])
+
+    def subset(self, task_filter: dict[str, str]):
+        # extract the filtered task_name subset
+        task_name_subset = task_list_from_metadata(self.task_metadata, task_filter)
+
+        # return the sub-benchmark
+        filter_str = ",".join([f"{col_name}=/{regex}/" for col_name, regex in task_filter.items()])
+        return Benchmark(
+            name=f"{self.name}[{filter_str}]",
+            high_level_action_set_args=self.high_level_action_set_args,
+            env_args_list=[
+                env_args
+                for env_args in self.env_args_list
+                if env_args.task_name in task_name_subset
+            ],
+            task_metadata=self.task_metadata,
+        )
 
 
-def task_list_from_metadata(benchmark: str, *args, **kwargs):
-    return task_list_from_csv(
-        io.StringIO(pkgutil.get_data(__name__, f"task_metadata/{benchmark}.csv").decode("utf-8")),
-        *args,
-        **kwargs,
+def task_metadata(benchmark_name: str):
+    return task_metadata_from_csv(
+        io.StringIO(
+            pkgutil.get_data(__name__, f"task_metadata/{benchmark_name}.csv").decode("utf-8")
+        )
     )
 
 
-def task_list_from_csv(
-    filepath_or_bytes: str | pathlib.Path | io.IOBase,
-    filters: dict[str, str] = {},
-):
-    # read task list with metadata
-    df: pd.DataFrame = pd.read_csv(filepath_or_bytes)
+def task_metadata_from_csv(filepath):
+    return pd.read_csv(filepath).fillna("")
+
+
+def task_list_from_metadata(metadata: pd.DataFrame, filter: dict[str, str] = {}):
+    df = metadata
     # filter the desired columns (AND filter)
-    for col_name, regex in filters.items():
-        filter = df[col_name].astype(str).str.contains(regex, regex=True)
-        df = df[filter]
+    for col_name, regex in filter.items():
+        col_filter = df[col_name].astype(str).str.contains(regex, regex=True)
+        df = df[col_filter]
     # return only the task names
-    return df["task_name"]
+    return list(df["task_name"])
 
 
 # These are meant as the default high-level action set to fairly evaluate agents on each benchmark.
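The new task_list_from_metadata applies an AND filter: a row survives only if every (column, regex) pair matches. A minimal self-contained sketch of that semantics, using made-up metadata rows purely for illustration:

import pandas as pd

# Made-up metadata rows, for illustration only.
metadata = pd.DataFrame(
    [
        {"task_name": "miniwob.click-button", "miniwob_category": "original"},
        {"task_name": "miniwob.click-dialog", "miniwob_category": "debug"},
        {"task_name": "miniwob.book-flight", "miniwob_category": "original"},
    ]
)

# Same AND semantics as task_list_from_metadata: each column regex must match.
df = metadata
for col_name, regex in {"task_name": "click", "miniwob_category": "original"}.items():
    df = df[df[col_name].astype(str).str.contains(regex, regex=True)]

print(list(df["task_name"]))  # ['miniwob.click-button']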
Expand Down Expand Up @@ -117,23 +152,25 @@ def task_list_from_csv(
name="miniwob_all",
high_level_action_set_args=DEFAULT_HIGHLEVEL_ACTION_SET_ARGS["miniwob"],
env_args_list=_make_env_args_list_from_repeat_tasks(
task_list=task_list_from_metadata(benchmark="miniwob"),
task_list=task_list_from_metadata(metadata=task_metadata("miniwob")),
max_steps=10,
n_repeats=10,
seeds_rng=np.random.RandomState(42),
),
task_metadata=task_metadata("miniwob"),
),
"miniwob_webgum": lambda: Benchmark(
name="miniwob_webgum",
high_level_action_set_args=DEFAULT_HIGHLEVEL_ACTION_SET_ARGS["miniwob"],
env_args_list=_make_env_args_list_from_repeat_tasks(
task_list=task_list_from_metadata(
benchmark="miniwob", filters={"webgum_subset": "True"}
metadata=task_metadata("miniwob"), filter={"webgum_subset": "True"}
),
max_steps=10,
n_repeats=10,
seeds_rng=np.random.RandomState(42),
),
task_metadata=task_metadata("miniwob"),
),
"miniwob_tiny_test": lambda: Benchmark(
name="miniwob_tiny_test",
@@ -144,73 +181,82 @@ def task_list_from_csv(
             n_repeats=2,
             seeds_rng=np.random.RandomState(42),
         ),
+        task_metadata=task_metadata("miniwob"),
     ),
     "miniwob_train": lambda: Benchmark(
         name="miniwob_train",
         high_level_action_set_args=DEFAULT_HIGHLEVEL_ACTION_SET_ARGS["miniwob"],
         env_args_list=_make_env_args_list_from_repeat_tasks(
             task_list=task_list_from_metadata(
-                benchmark="miniwob",
-                filters={"miniwob_category": "original|nodelay|debug|additional"},
+                metadata=task_metadata("miniwob"),
+                filter={"miniwob_category": "original|nodelay|debug|additional"},
             ),
             max_steps=10,
             n_repeats=10,
             seeds_rng=np.random.RandomState(42),
         ),
+        task_metadata=task_metadata("miniwob"),
     ),
     "miniwob_test": lambda: Benchmark(
         name="miniwob_test",
         high_level_action_set_args=DEFAULT_HIGHLEVEL_ACTION_SET_ARGS["miniwob"],
         env_args_list=_make_env_args_list_from_repeat_tasks(
             task_list=task_list_from_metadata(
-                benchmark="miniwob", filters={"miniwob_category": "hidden test"}
+                metadata=task_metadata("miniwob"), filter={"miniwob_category": "hidden test"}
             ),
             max_steps=10,
             n_repeats=10,
             seeds_rng=np.random.RandomState(42),
         ),
+        task_metadata=task_metadata("miniwob"),
     ),
     "webarena": lambda: Benchmark(
         name="webarena",
         high_level_action_set_args=DEFAULT_HIGHLEVEL_ACTION_SET_ARGS["webarena"],
         env_args_list=_make_env_args_list_from_repeat_tasks(
-            task_list=task_list_from_metadata(benchmark="webarena"),
+            task_list=task_list_from_metadata(metadata=task_metadata("webarena")),
             max_steps=15,
             n_repeats=1,
             seeds_rng=np.random.RandomState(42),
         ),
+        task_metadata=task_metadata("webarena"),
    ),
     "visualwebarena": lambda: Benchmark(
         name="visualwebarena",
         high_level_action_set_args=DEFAULT_HIGHLEVEL_ACTION_SET_ARGS["visualwebarena"],
         env_args_list=_make_env_args_list_from_repeat_tasks(
-            task_list=task_list_from_metadata(benchmark="visualwebarena"),
+            task_list=task_list_from_metadata(metadata=task_metadata("visualwebarena")),
             max_steps=15,
             n_repeats=1,
             seeds_rng=np.random.RandomState(42),
         ),
+        task_metadata=task_metadata("visualwebarena"),
     ),
     "workarena_l1": lambda: Benchmark(
         name="workarena_l1",
         high_level_action_set_args=DEFAULT_HIGHLEVEL_ACTION_SET_ARGS["workarena_l1"],
         env_args_list=_make_env_args_list_from_repeat_tasks(
-            task_list=task_list_from_metadata(benchmark="workarena", filters={"level": "l1"}),
+            task_list=task_list_from_metadata(
+                metadata=task_metadata("workarena"), filter={"level": "l1"}
+            ),
             max_steps=15,
             n_repeats=10,
             seeds_rng=np.random.RandomState(42),
         ),
+        task_metadata=task_metadata("workarena"),
     ),
     "workarena_l1_sort": lambda: Benchmark(
         name="workarena_l1_sort",
         high_level_action_set_args=DEFAULT_HIGHLEVEL_ACTION_SET_ARGS["workarena_l1"],
         env_args_list=_make_env_args_list_from_repeat_tasks(
             task_list=task_list_from_metadata(
-                benchmark="workarena", filters={"level": "l1", "category": "list-sort"}
+                metadata=task_metadata("workarena"), filter={"level": "l1", "category": "list-sort"}
             ),
             max_steps=15,
             n_repeats=10,
             seeds_rng=np.random.RandomState(42),
         ),
+        task_metadata=task_metadata("workarena"),
     ),
     "workarena_l2_agent_curriculum_eval": lambda: Benchmark(
         name="workarena_l2_agent_curriculum_eval",
@@ -222,6 +268,7 @@ def task_list_from_csv(
             max_steps=50,
             curriculum_type="agent",
         ),
+        task_metadata=task_metadata("workarena"),
     ),
     "workarena_l3_agent_curriculum_eval": lambda: Benchmark(
         name="workarena_l3_agent_curriculum_eval",
@@ -233,6 +280,7 @@ def task_list_from_csv(
             max_steps=50,
             curriculum_type="agent",
         ),
+        task_metadata=task_metadata("workarena"),
     ),
 }
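With metadata attached to each Benchmark, a sub-benchmark can now be carved out by regex, and filters compose across calls. A short usage sketch, assuming BENCHMARKS and Benchmark are importable from this module (browsergym.experiments.benchmark) as defined in this commit:

from browsergym.experiments.benchmark import BENCHMARKS

# Build the full benchmark, then filter it down; filter keys combine as an
# AND, and the resulting name records the applied filters.
benchmark = BENCHMARKS["miniwob_all"]()
subset = benchmark.subset(task_filter={"task_name": "click", "miniwob_category": "original"})
print(subset.name)  # miniwob_all[task_name=/click/,miniwob_category=/original/]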

24 changes: 24 additions & 0 deletions tests/experiments/test_benchmark.py
@@ -67,11 +67,35 @@ def test_build_benchmarks():
         benchmark = benchmark_builder()
         assert name == benchmark.name
         assert benchmark.env_args_list  # non-empty
+        assert benchmark.task_metadata is not None
         assert len(benchmark.env_args_list) == expected_bench_size[name]
         benchmark_bis = Benchmark.from_json(benchmark.to_json())
         assert benchmark.to_dict() == benchmark_bis.to_dict()
 
 
+def test_benchmark_subset():
+    benchmark: Benchmark = BENCHMARKS["miniwob_all"]()
+
+    benchmark_subset = benchmark.subset(task_filter={"task_name": "click"})
+    assert len(benchmark_subset.env_args_list) == 31 * 10
+    assert benchmark_subset.name == "miniwob_all[task_name=/click/]"
+
+    benchmark_subset_1 = benchmark_subset.subset(task_filter={"miniwob_category": "original"})
+    benchmark_subset_2 = benchmark.subset(
+        task_filter={"task_name": "click", "miniwob_category": "original"}
+    )
+
+    assert benchmark_subset_1.name == "miniwob_all[task_name=/click/][miniwob_category=/original/]"
+    assert benchmark_subset_2.name == "miniwob_all[task_name=/click/,miniwob_category=/original/]"
+
+    dict_1 = benchmark_subset_1.to_dict()
+    dict_1.pop("name")
+    dict_2 = benchmark_subset_2.to_dict()
+    dict_2.pop("name")
+
+    assert dict_1 == dict_2
+
+
 def test_run_mock_benchmark():
     benchmark = Benchmark(
         name="miniwob_click_test",
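Because task_metadata round-trips through the records encoder/decoder, serialization keeps working as before. A sketch of the round trip exercised by the updated test, under the same import assumption as above:

from browsergym.experiments.benchmark import BENCHMARKS, Benchmark

# The DataFrame is encoded as a list of records in JSON and decoded back,
# so a serialized benchmark reloads with its metadata intact.
benchmark = BENCHMARKS["miniwob_all"]()
clone = Benchmark.from_json(benchmark.to_json())
assert benchmark.to_dict() == clone.to_dict()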
