Skip to content

Commit

Permalink
minor refactors
Browse files Browse the repository at this point in the history
  • Loading branch information
gasse committed Oct 8, 2024
1 parent 2fb7138 commit 04b44f7
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 15 deletions.
27 changes: 14 additions & 13 deletions browsergym/experiments/src/browsergym/experiments/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def make_action_set(self):
@dataclass
class Benchmark(DataClassJsonMixin):
name: str
high_level_action_set: HighLevelActionSetArgs
high_level_action_set_args: HighLevelActionSetArgs
env_args_list: list[EnvArgs]


Expand Down Expand Up @@ -66,7 +66,7 @@ def task_list_from_csv(

# These are meant as the default high-level action sets to fairly evaluate agents on each benchmark.
# They are mostly arbitrary; the important thing is to evaluate different agents using the same action set for fairness.
DEFAULT_HIGHLEVEL_ACTION_SETS = {
DEFAULT_HIGHLEVEL_ACTION_SET_ARGS = {
"miniwob": HighLevelActionSetArgs(
subsets=["bid", "coord"],
multiaction=False,
Expand Down Expand Up @@ -111,10 +111,11 @@ def task_list_from_csv(
),
}

# All benchmarks are callables designed for lazy loading, e.g. `bench = BENCHMARKS["miniwob_all"]()`
BENCHMARKS = {
"miniwob_all": lambda: Benchmark(
name="miniwob_all",
high_level_action_set=DEFAULT_HIGHLEVEL_ACTION_SETS["miniwob"],
high_level_action_set_args=DEFAULT_HIGHLEVEL_ACTION_SET_ARGS["miniwob"],
env_args_list=_make_env_args_list_from_repeat_tasks(
task_list=task_list_from_metadata(benchmark="miniwob"),
max_steps=10,
Expand All @@ -124,7 +125,7 @@ def task_list_from_csv(
),
"miniwob_webgum": lambda: Benchmark(
name="miniwob_webgum",
high_level_action_set=DEFAULT_HIGHLEVEL_ACTION_SETS["miniwob"],
high_level_action_set_args=DEFAULT_HIGHLEVEL_ACTION_SET_ARGS["miniwob"],
env_args_list=_make_env_args_list_from_repeat_tasks(
task_list=task_list_from_metadata(
benchmark="miniwob", filters={"webgum_subset": "True"}
Expand All @@ -136,7 +137,7 @@ def task_list_from_csv(
),
"miniwob_tiny_test": lambda: Benchmark(
name="miniwob_tiny_test",
high_level_action_set=DEFAULT_HIGHLEVEL_ACTION_SETS["miniwob"],
high_level_action_set_args=DEFAULT_HIGHLEVEL_ACTION_SET_ARGS["miniwob"],
env_args_list=_make_env_args_list_from_repeat_tasks(
task_list=["miniwob.click-dialog", "miniwob.click-checkboxes"],
max_steps=5,
Expand All @@ -146,7 +147,7 @@ def task_list_from_csv(
),
"miniwob_train": lambda: Benchmark(
name="miniwob_train",
high_level_action_set=DEFAULT_HIGHLEVEL_ACTION_SETS["miniwob"],
high_level_action_set_args=DEFAULT_HIGHLEVEL_ACTION_SET_ARGS["miniwob"],
env_args_list=_make_env_args_list_from_repeat_tasks(
task_list=task_list_from_metadata(
benchmark="miniwob",
Expand All @@ -159,7 +160,7 @@ def task_list_from_csv(
),
"miniwob_test": lambda: Benchmark(
name="miniwob_test",
high_level_action_set=DEFAULT_HIGHLEVEL_ACTION_SETS["miniwob"],
high_level_action_set_args=DEFAULT_HIGHLEVEL_ACTION_SET_ARGS["miniwob"],
env_args_list=_make_env_args_list_from_repeat_tasks(
task_list=task_list_from_metadata(
benchmark="miniwob", filters={"miniwob_category": "hidden test"}
Expand All @@ -171,7 +172,7 @@ def task_list_from_csv(
),
"webarena": lambda: Benchmark(
name="webarena",
high_level_action_set=DEFAULT_HIGHLEVEL_ACTION_SETS["webarena"],
high_level_action_set_args=DEFAULT_HIGHLEVEL_ACTION_SET_ARGS["webarena"],
env_args_list=_make_env_args_list_from_repeat_tasks(
task_list=task_list_from_metadata(benchmark="webarena"),
max_steps=15,
Expand All @@ -181,7 +182,7 @@ def task_list_from_csv(
),
"visualwebarena": lambda: Benchmark(
name="visualwebarena",
high_level_action_set=DEFAULT_HIGHLEVEL_ACTION_SETS["visualwebarena"],
high_level_action_set_args=DEFAULT_HIGHLEVEL_ACTION_SET_ARGS["visualwebarena"],
env_args_list=_make_env_args_list_from_repeat_tasks(
task_list=task_list_from_metadata(benchmark="visualwebarena"),
max_steps=15,
Expand All @@ -191,7 +192,7 @@ def task_list_from_csv(
),
"workarena_l1": lambda: Benchmark(
name="workarena_l1",
high_level_action_set=DEFAULT_HIGHLEVEL_ACTION_SETS["workarena_l1"],
high_level_action_set_args=DEFAULT_HIGHLEVEL_ACTION_SET_ARGS["workarena_l1"],
env_args_list=_make_env_args_list_from_repeat_tasks(
task_list=task_list_from_metadata(benchmark="workarena", filters={"level": "l1"}),
max_steps=15,
Expand All @@ -201,7 +202,7 @@ def task_list_from_csv(
),
"workarena_l1_sort": lambda: Benchmark(
name="workarena_l1_sort",
high_level_action_set=DEFAULT_HIGHLEVEL_ACTION_SETS["workarena_l1"],
high_level_action_set_args=DEFAULT_HIGHLEVEL_ACTION_SET_ARGS["workarena_l1"],
env_args_list=_make_env_args_list_from_repeat_tasks(
task_list=task_list_from_metadata(
benchmark="workarena", filters={"level": "l1", "category": "list-sort"}
Expand All @@ -213,7 +214,7 @@ def task_list_from_csv(
),
"workarena_l2_agent_curriculum": lambda: Benchmark(
name="workarena_l2_agent_curriculum",
high_level_action_set=DEFAULT_HIGHLEVEL_ACTION_SETS["workarena"],
high_level_action_set_args=DEFAULT_HIGHLEVEL_ACTION_SET_ARGS["workarena"],
env_args_list=_make_env_args_list_from_workarena_curriculum(
level="l2",
task_category_filter=None,
Expand All @@ -224,7 +225,7 @@ def task_list_from_csv(
),
"workarena_l3_agent_curriculum": lambda: Benchmark(
name="workarena_l3_agent_curriculum",
high_level_action_set=DEFAULT_HIGHLEVEL_ACTION_SETS["workarena"],
high_level_action_set_args=DEFAULT_HIGHLEVEL_ACTION_SET_ARGS["workarena"],
env_args_list=_make_env_args_list_from_workarena_curriculum(
level="l3",
task_category_filter=None,
Expand Down
6 changes: 4 additions & 2 deletions tests/experiments/test_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def test_build_benchmarks():
def test_run_mock_benchmark():
benchmark = Benchmark(
name="miniwob_click_test",
high_level_action_set=HighLevelActionSetArgs(
high_level_action_set_args=HighLevelActionSetArgs(
subsets=["bid"],
multiaction=False,
strict=False,
Expand All @@ -91,7 +91,9 @@ def test_run_mock_benchmark():
)

for env_args in benchmark.env_args_list:
agent_args = MiniwobTestAgentArgs(high_level_action_set=benchmark.high_level_action_set)
agent_args = MiniwobTestAgentArgs(
high_level_action_set=benchmark.high_level_action_set_args
)
exp_args = ExpArgs(
agent_args=agent_args,
env_args=env_args,
Expand Down

0 comments on commit 04b44f7

Please sign in to comment.