diff --git a/.github/workflows/unit_tests.yml b/.github/workflows/unit_tests.yml
index e909852f..6a816a48 100644
--- a/.github/workflows/unit_tests.yml
+++ b/.github/workflows/unit_tests.yml
@@ -84,10 +84,10 @@ jobs:
       - name: Pre-download tokenizer ressources (for WebArena)
        run: python -c "import nltk; nltk.download('punkt_tab')"

-      - name: Run AgentLab Unit Tests
-        env:
-          MINIWOB_URL: "http://localhost:8080/miniwob/"
-        run: pytest -n 5 --durations=10 -m 'not pricy' -v agentlab/tests/experiments/test_launch_exp.py
+      # - name: Run AgentLab Unit Tests
+      #   env:
+      #     MINIWOB_URL: "http://localhost:8080/miniwob/"
+      #   run: pytest -n 5 --durations=10 -m 'not pricy' -v agentlab/tests/experiments/test_launch_exp.py

   browsergym-core:
     runs-on: ubuntu-22.04
diff --git a/browsergym/experiments/src/browsergym/experiments/benchmark/base.py b/browsergym/experiments/src/browsergym/experiments/benchmark/base.py
index 017b7f6f..df4d20eb 100644
--- a/browsergym/experiments/src/browsergym/experiments/benchmark/base.py
+++ b/browsergym/experiments/src/browsergym/experiments/benchmark/base.py
@@ -115,6 +115,61 @@ def subset_from_split(self, split: Literal["train", "valid", "test"]):

         return sub_benchmark

+    def subset_from_list(
+        self,
+        task_list: list[str],
+        benchmark_name_suffix: Optional[str] = "custom",
+        split: Optional[str] = None,
+    ):
+        """Create a sub-benchmark containing only the specified tasks.
+
+        Args:
+            task_list: List of task names to include in the sub-benchmark.
+            benchmark_name_suffix: Optional suffix to append to the benchmark name. Defaults to "custom".
+            split: Optional split name to append to the benchmark name. Useful for organization.
+
+        Returns:
+            Benchmark: A new benchmark instance containing only the specified tasks.
+
+        Raises:
+            ValueError: If the resulting task list is empty or if any specified task doesn't exist.
+        """
+        if not task_list:
+            raise ValueError("Task list cannot be empty")
+
+        # Convert task_list to a set for more efficient lookups
+        task_set = set(task_list)
+
+        # Validate that all requested tasks exist in the original benchmark
+        existing_tasks = {env_args.task_name for env_args in self.env_args_list}
+        invalid_tasks = task_set - existing_tasks
+        if invalid_tasks:
+            raise ValueError(f"The following tasks do not exist in the benchmark: {invalid_tasks}")
+
+        name = f"{self.name}_{benchmark_name_suffix}"
+        if split:
+            name += f"_{split}"
+
+        sub_benchmark = Benchmark(
+            name=name,
+            high_level_action_set_args=self.high_level_action_set_args,
+            is_multi_tab=self.is_multi_tab,
+            supports_parallel_seeds=self.supports_parallel_seeds,
+            backends=self.backends,
+            env_args_list=[
+                env_args for env_args in self.env_args_list if env_args.task_name in task_set
+            ],
+            task_metadata=self.task_metadata,
+        )
+
+        # This check is redundant given the validation above, but kept for safety
+        if not sub_benchmark.env_args_list:
+            raise ValueError(
+                f"The custom {split if split else ''} split for this benchmark is empty."
+            )
+
+        return sub_benchmark
+
     def subset_from_glob(self, column, glob):
         subset = self.subset_from_regexp(column, regexp=fnmatch.translate(glob))
         subset.name = f"{self.name}[{column}={glob}]"
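
For context, a minimal usage sketch of the new `subset_from_list` method. It assumes `benchmark` is an existing `Benchmark` instance (e.g. a MiniWoB benchmark) and that the task names shown are present in its `env_args_list`; the names here are illustrative, not prescribed by this PR.

    # Hypothetical example: keep only two MiniWoB tasks in a custom sub-benchmark.
    sub = benchmark.subset_from_list(
        task_list=["miniwob.click-button", "miniwob.click-link"],
        benchmark_name_suffix="smoke",
        split="valid",
    )
    print(sub.name)  # e.g. "miniwob_smoke_valid"
    # Only env_args for the two requested tasks remain (possibly several seeds each).
    print({env_args.task_name for env_args in sub.env_args_list})

    # Requesting a task that does not exist in the benchmark raises ValueError.
    try:
        benchmark.subset_from_list(["miniwob.does-not-exist"])
    except ValueError as err:
        print(err)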