diff --git a/prompt2model/dataset_retriever/__init__.py b/prompt2model/dataset_retriever/__init__.py index 7cc03238d..781797c4e 100644 --- a/prompt2model/dataset_retriever/__init__.py +++ b/prompt2model/dataset_retriever/__init__.py @@ -1,6 +1,6 @@ """Import DatasetRetriever classes.""" from prompt2model.dataset_retriever.base import DatasetRetriever -from prompt2model.dataset_retriever.hf_dataset_retriever import ( +from prompt2model.dataset_retriever.description_dataset_retriever import ( DatasetInfo, DescriptionDatasetRetriever, ) diff --git a/prompt2model/dataset_retriever/base.py b/prompt2model/dataset_retriever/base.py index 4d2318123..9d205d585 100644 --- a/prompt2model/dataset_retriever/base.py +++ b/prompt2model/dataset_retriever/base.py @@ -2,6 +2,7 @@ from __future__ import annotations # noqa FI58 +import dataclasses from abc import ABC, abstractmethod import datasets @@ -9,6 +10,21 @@ from prompt2model.prompt_parser import PromptSpec +@dataclasses.dataclass +class DatasetInfo: + """Store the dataset name, description, and query-dataset score for each dataset. + + Args: + name: The name of the dataset. + description: The description of the dataset. + score: The retrieval score of the dataset. + """ + + name: str + description: str + score: float + + # pylint: disable=too-few-public-methods class DatasetRetriever(ABC): """A class for retrieving datasets.""" diff --git a/prompt2model/dataset_retriever/hf_dataset_retriever.py b/prompt2model/dataset_retriever/description_dataset_retriever.py similarity index 70% rename from prompt2model/dataset_retriever/hf_dataset_retriever.py rename to prompt2model/dataset_retriever/description_dataset_retriever.py index 9eaa194e6..63cd9dd7c 100644 --- a/prompt2model/dataset_retriever/hf_dataset_retriever.py +++ b/prompt2model/dataset_retriever/description_dataset_retriever.py @@ -3,51 +3,19 @@ from __future__ import annotations # noqa FI58 import json +import logging import os import urllib.request import datasets -import numpy as np import torch -from prompt2model.dataset_retriever.base import DatasetRetriever +from prompt2model.dataset_retriever.base import DatasetInfo, DatasetRetriever from prompt2model.prompt_parser import PromptSpec from prompt2model.utils import encode_text, retrieve_objects datasets.utils.logging.disable_progress_bar() - - -class DatasetInfo: - """Store the dataset name, description, and query-dataset score for each dataset.""" - - def __init__( - self, - name: str, - description: str, - score: float, - ): - """Initialize a DatasetInfo object. - - Args: - name: The name of the dataset. - description: The description of the dataset. - score: The similarity of the dataset to a given prompt from a user. - """ - self.name = name - self.description = description - self.score = score - - -def input_string(): - """Read a string from stdin.""" - description = str(input()) - return description - - -def input_y_n() -> bool: - """Get a yes/no answer from the user via stdin.""" - y_n = str(input()) - return not (y_n.strip() == "" or y_n.strip().lower() == "n") +logger = logging.getLogger(__name__) class DescriptionDatasetRetriever(DatasetRetriever): @@ -80,14 +48,21 @@ def __init__( self.first_stage_search_depth = first_stage_search_depth self.max_search_depth = max_search_depth self.encoder_model_name = encoder_model_name + self.device = device + self.dataset_info_file = dataset_info_file + self.initialize_search_index() + + def initialize_search_index(self) -> None: + """Initialize the search index.""" self.dataset_infos: list[DatasetInfo] = [] - if not os.path.exists(dataset_info_file): + if not os.path.exists(self.dataset_info_file): # Download the dataset search index if one is not on disk already. + logger.info("Downlidng the dataset search index") urllib.request.urlretrieve( "http://phontron.com/data/prompt2model/dataset_index.json", - dataset_info_file, + self.dataset_info_file, ) - self.full_dataset_metadata = json.load(open(dataset_info_file, "r")) + self.full_dataset_metadata = json.load(open(self.dataset_info_file, "r")) for dataset_name in sorted(self.full_dataset_metadata.keys()): self.dataset_infos.append( DatasetInfo( @@ -96,47 +71,59 @@ def __init__( score=0.0, ) ) - self.device = device - - assert not os.path.isdir( - search_index_path - ), f"Search index must either be a valid file or not exist yet. But {search_index_path} is provided." # noqa E501 - - def encode_dataset_descriptions( - self, dataset_infos, search_index_path - ) -> np.ndarray: - """Encode dataset descriptions into a vector for indexing.""" - dataset_descriptions = [ - dataset_info.description for dataset_info in dataset_infos - ] - dataset_vectors = encode_text( + if os.path.isdir(self.search_index_path): + raise ValueError( + "Search index must either be a valid file or not exist yet. " + "But {self.search_index_path} is provided." + ) + logger.info("Creating dataset descriptions") + encode_text( self.encoder_model_name, - text_to_encode=dataset_descriptions, - encoding_file=search_index_path, + text_to_encode=[x.description for x in self.dataset_infos], + encoding_file=self.search_index_path, device=self.device, ) - return dataset_vectors + # ---------------------------- Utility Functions ---------------------------- @staticmethod - def print_divider(): + def _input_string(): + """Utility function to read a string from stdin.""" + description = str(input()) + return description + + @staticmethod + def _input_y_n() -> bool: + """Utility function to get a yes/no answer from the user via stdin.""" + y_n = str(input()) + return not (y_n.strip() == "" or y_n.strip().lower() == "n") + + @staticmethod + def _print_divider(): """Utility function to assist with the retriever's command line interface.""" print("\n-------------------------------------------------\n") - def choose_dataset(self, top_datasets: list[DatasetInfo]) -> str | None: - """Have the user choose an appropriate dataset from a list of top datasets.""" - self.print_divider() + def choose_dataset_by_cli(self, top_datasets: list[DatasetInfo]) -> str | None: + """Have the user choose an appropriate dataset from a list of top datasets. + + Args: + top_datasets: A list of top datasets to choose from. + + Returns: + The name of the chosen dataset, or None if no dataset is chosen as relevant. + """ + self._print_divider() print("Here are the datasets I've retrieved for you:") print("#\tName\tDescription") for i, d in enumerate(top_datasets): description_no_spaces = d.description.replace("\n", " ") print(f"{i+1}):\t{d.name}\t{description_no_spaces}") - self.print_divider() + self._print_divider() print( "If none of these are relevant to your prompt, we'll only use " + "generated data. Are any of these datasets relevant? (y/N)" ) - any_datasets_relevant = input_y_n() + any_datasets_relevant = self._input_y_n() if any_datasets_relevant: print( "Which dataset would you like to use? Give the number between " @@ -146,7 +133,7 @@ def choose_dataset(self, top_datasets: list[DatasetInfo]) -> str | None: chosen_dataset_name = top_datasets[dataset_idx - 1].name else: chosen_dataset_name = None - self.print_divider() # noqa E501 + self._print_divider() return chosen_dataset_name @staticmethod @@ -159,11 +146,11 @@ def canonicalize_dataset_using_columns_for_split( input_col = [] output_col = [] for i in range(len(dataset_split)): - input_string = "" + curr_string = "" for col in input_columns: - input_string += f"{col}: {dataset_split[i][col]}\n" - input_string = input_string.strip() - input_col.append(input_string) + curr_string += f"{col}: {dataset_split[i][col]}\n" + curr_string = curr_string.strip() + input_col.append(curr_string) output_col.append(dataset_split[i][output_column]) return datasets.Dataset.from_dict( {"input_col": input_col, "output_col": output_col} @@ -183,18 +170,25 @@ def canonicalize_dataset_using_columns( ) return datasets.DatasetDict(dataset_dict) - def canonicalize_dataset(self, dataset_name: str) -> datasets.DatasetDict: - """Canonicalize a dataset into a suitable text-to-text format.""" + def canonicalize_dataset_by_cli(self, dataset_name: str) -> datasets.DatasetDict: + """Canonicalize a dataset into a suitable text-to-text format. + + Args: + dataset_name: The name of the dataset to canonicalize. + + Returns: + A canonicalized dataset. + """ configs = datasets.get_dataset_config_names(dataset_name) chosen_config = None if len(configs) == 1: chosen_config = configs[0] else: - self.print_divider() + self._print_divider() print(f"Multiple dataset configs available: {configs}") while chosen_config is None: print("Which dataset config would you like to use for this?") - user_response = input_string() + user_response = self._input_string() if user_response in configs: chosen_config = user_response else: @@ -202,7 +196,7 @@ def canonicalize_dataset(self, dataset_name: str) -> datasets.DatasetDict: f"Invalid config provided: {user_response}. Please choose " + "from {configs}\n\n" ) - self.print_divider() + self._print_divider() dataset = datasets.load_dataset(dataset_name, chosen_config) assert "train" in dataset @@ -212,14 +206,14 @@ def canonicalize_dataset(self, dataset_name: str) -> datasets.DatasetDict: assert len(dataset["train"]) > 0 example_rows = json.dumps(dataset["train"][0], indent=4) - self.print_divider() + self._print_divider() print(f"Loaded dataset. Example row:\n{example_rows}\n") print( "Which column(s) should we use as input? Provide a comma-separated " + f"list from: {train_columns_formatted}." ) - user_response = input_string() + user_response = self._input_string() input_columns = [c.strip() for c in user_response.split(",")] print(f"Will use the columns {json.dumps(input_columns)} as input.\n") @@ -229,7 +223,7 @@ def canonicalize_dataset(self, dataset_name: str) -> datasets.DatasetDict: "Which column(s) should we use as the target? Choose a single " + f"value from: {train_columns_formatted}." ) - user_response = input_string() + user_response = self._input_string() if user_response in train_columns: output_column = user_response else: @@ -238,41 +232,37 @@ def canonicalize_dataset(self, dataset_name: str) -> datasets.DatasetDict: + f"from {train_columns}\n\n" ) print(f'Will use the column "{output_column}" as our target.\n') - self.print_divider() + self._print_divider() canonicalized_dataset = self.canonicalize_dataset_using_columns( dataset, input_columns, output_column ) return canonicalized_dataset - def retrieve_dataset_dict( + def retrieve_top_datasets( self, prompt_spec: PromptSpec, - ) -> datasets.DatasetDict | None: - """Select a dataset from a prompt using a dual-encoder retriever. + ) -> list[DatasetInfo]: + """Retrieve the top datasets for a prompt. + + Specifically, the datasets are scored using a dual-encoder retriever model + and the datasets with the highest similarity scores with the query are returned. Args: prompt_spec: A prompt whose instruction field we use to retrieve datasets. - Return: - A list of relevant datasets dictionaries. + Returns: + A list of the top datasets for the prompt according to retriever score. """ - if not os.path.exists(self.search_index_path): - print("Creating dataset descriptions") - self.encode_dataset_descriptions(self.dataset_infos, self.search_index_path) - - query_text = prompt_spec.instruction - query_vector = encode_text( self.encoder_model_name, - text_to_encode=query_text, + text_to_encode=prompt_spec.instruction, device=self.device, ) - dataset_names = [dataset_info.name for dataset_info in self.dataset_infos] ranked_list = retrieve_objects( query_vector, self.search_index_path, - dataset_names, + [x.name for x in self.dataset_infos], self.first_stage_search_depth, ) top_dataset_infos = [] @@ -287,8 +277,24 @@ def retrieve_dataset_dict( sorted_list = sorted(top_dataset_infos, key=lambda x: x.score, reverse=True)[ : self.max_search_depth ] - assert len(sorted_list) > 0, "No datasets retrieved from search index." - top_dataset_name = self.choose_dataset(sorted_list) + if len(sorted_list) == 0: + raise ValueError("No datasets retrieved from search index.") + return sorted_list + + def retrieve_dataset_dict( + self, + prompt_spec: PromptSpec, + ) -> datasets.DatasetDict | None: + """Select a dataset from a prompt using a dual-encoder retriever. + + Args: + prompt_spec: A prompt whose instruction field we use to retrieve datasets. + + Return: + A list of relevant datasets dictionaries. + """ + sorted_list = self.retrieve_top_datasets(prompt_spec) + top_dataset_name = self.choose_dataset_by_cli(sorted_list) if top_dataset_name is None: return None - return self.canonicalize_dataset(top_dataset_name) + return self.canonicalize_dataset_by_cli(top_dataset_name) diff --git a/tests/dataset_retriever_test.py b/tests/dataset_retriever_test.py index 6afe36850..dc2220f53 100644 --- a/tests/dataset_retriever_test.py +++ b/tests/dataset_retriever_test.py @@ -3,6 +3,7 @@ from __future__ import annotations # noqa FI58 import os +import pickle import tempfile from unittest.mock import patch @@ -32,17 +33,18 @@ def test_initialize_dataset_retriever(): def test_encode_model_retriever(): """Test loading a small Tevatron model.""" - with tempfile.NamedTemporaryFile(mode="w", suffix=".pkl") as f: + with tempfile.TemporaryDirectory() as tempdir: + temporary_file = os.path.join(tempdir, "search_index.pkl") retriever = DescriptionDatasetRetriever( - search_index_path=f.name, + search_index_path=temporary_file, first_stage_search_depth=3, max_search_depth=3, encoder_model_name=TINY_DUAL_ENCODER_NAME, dataset_info_file="test_helpers/dataset_index_tiny.json", ) - model_vectors = retriever.encode_dataset_descriptions( - retriever.dataset_infos, f.name - ) + retriever.initialize_search_index() + with open(temporary_file, "rb") as f: + model_vectors, _ = pickle.load(f) assert model_vectors.shape == (3, 128) @@ -119,17 +121,17 @@ def mock_canonicalize_dataset(self, dataset_name: str) -> DatasetDict: @patch( - "prompt2model.dataset_retriever.hf_dataset_retriever.encode_text", + "prompt2model.dataset_retriever.description_dataset_retriever.encode_text", return_value=np.array([[1, 0, 0]]), ) @patch.object( DescriptionDatasetRetriever, - "choose_dataset", + "choose_dataset_by_cli", new=mock_choose_dataset, ) @patch.object( DescriptionDatasetRetriever, - "canonicalize_dataset", + "canonicalize_dataset_by_cli", new=mock_canonicalize_dataset, ) def test_retrieve_dataset_dict_when_search_index_exists(encode_text): @@ -151,7 +153,7 @@ def test_retrieve_dataset_dict_when_search_index_exists(encode_text): mock_prompt = MockPromptSpec(task_type=TaskType.TEXT_GENERATION) retrieved_dataset = retriever.retrieve_dataset_dict(mock_prompt) - assert encode_text.call_count == 1 + assert encode_text.call_count == 2 for split_name in ["train", "val", "test"]: assert split_name in retrieved_dataset split = retrieved_dataset[split_name] @@ -163,25 +165,18 @@ def test_retrieve_dataset_dict_when_search_index_exists(encode_text): assert split[0]["output_col"] == "mammals" -@patch.object( - DescriptionDatasetRetriever, - "encode_dataset_descriptions", - new=lambda self, dataset_infos, index_file_name: create_test_search_index( - index_file_name - ), -) @patch( - "prompt2model.dataset_retriever.hf_dataset_retriever.encode_text", + "prompt2model.dataset_retriever.description_dataset_retriever.encode_text", return_value=np.array([[1, 0, 0]]), ) @patch.object( DescriptionDatasetRetriever, - "choose_dataset", + "choose_dataset_by_cli", new=mock_choose_dataset, ) @patch.object( DescriptionDatasetRetriever, - "canonicalize_dataset", + "canonicalize_dataset_by_cli", new=mock_canonicalize_dataset, ) def test_retrieve_dataset_dict_without_existing_search_index(encode_text): @@ -195,6 +190,7 @@ def test_retrieve_dataset_dict_without_existing_search_index(encode_text): encoder_model_name=TINY_DUAL_ENCODER_NAME, dataset_info_file="test_helpers/dataset_index_tiny.json", ) + create_test_search_index(temporary_file) assert [info.name for info in retriever.dataset_infos] == [ "search_qa", "squad", @@ -202,7 +198,7 @@ def test_retrieve_dataset_dict_without_existing_search_index(encode_text): ] mock_prompt = MockPromptSpec(task_type=TaskType.TEXT_GENERATION) retrieved_dataset = retriever.retrieve_dataset_dict(mock_prompt) - assert encode_text.call_count == 1 + assert encode_text.call_count == 2 for split_name in ["train", "val", "test"]: assert split_name in retrieved_dataset split = retrieved_dataset[split_name]