Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor dataset retriever to isolate CLI functions #294

Merged
merged 3 commits into from
Aug 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion prompt2model/dataset_retriever/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Import DatasetRetriever classes."""
from prompt2model.dataset_retriever.base import DatasetRetriever
from prompt2model.dataset_retriever.hf_dataset_retriever import (
from prompt2model.dataset_retriever.description_dataset_retriever import (
DatasetInfo,
DescriptionDatasetRetriever,
)
Expand Down
16 changes: 16 additions & 0 deletions prompt2model/dataset_retriever/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,29 @@

from __future__ import annotations # noqa FI58

import dataclasses
from abc import ABC, abstractmethod

import datasets

from prompt2model.prompt_parser import PromptSpec


@dataclasses.dataclass
class DatasetInfo:
    """Metadata record for one retrievable dataset.

    Attributes:
        name: Dataset name.
        description: Human-readable description of the dataset.
        score: Retrieval (query-dataset similarity) score for the dataset.
    """

    name: str
    description: str
    score: float


# pylint: disable=too-few-public-methods
class DatasetRetriever(ABC):
"""A class for retrieving datasets."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,51 +3,19 @@
from __future__ import annotations # noqa FI58

import json
import logging
import os
import urllib.request

import datasets
import numpy as np
import torch

from prompt2model.dataset_retriever.base import DatasetRetriever
from prompt2model.dataset_retriever.base import DatasetInfo, DatasetRetriever
from prompt2model.prompt_parser import PromptSpec
from prompt2model.utils import encode_text, retrieve_objects

datasets.utils.logging.disable_progress_bar()


class DatasetInfo:
    """Container for one dataset's name, description, and retrieval score."""

    def __init__(
        self,
        name: str,
        description: str,
        score: float,
    ):
        """Store the dataset metadata on this instance.

        Args:
            name: The name of the dataset.
            description: The description of the dataset.
            score: The similarity of the dataset to a given prompt from a user.
        """
        self.name, self.description, self.score = name, description, score


def input_string():
    """Read one line of text from stdin and return it."""
    return str(input())


def input_y_n() -> bool:
    """Read a yes/no answer from stdin.

    An empty response or 'n' (case-insensitive, surrounding whitespace
    ignored) counts as "no"; any other response counts as "yes".
    """
    answer = str(input()).strip().lower()
    return answer not in ("", "n")
logger = logging.getLogger(__name__)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I should change this in all the other components. 😅



class DescriptionDatasetRetriever(DatasetRetriever):
Expand Down Expand Up @@ -80,14 +48,21 @@ def __init__(
self.first_stage_search_depth = first_stage_search_depth
self.max_search_depth = max_search_depth
self.encoder_model_name = encoder_model_name
self.device = device
self.dataset_info_file = dataset_info_file
self.initialize_search_index()

def initialize_search_index(self) -> None:
"""Initialize the search index."""
self.dataset_infos: list[DatasetInfo] = []
if not os.path.exists(dataset_info_file):
if not os.path.exists(self.dataset_info_file):
# Download the dataset search index if one is not on disk already.
logger.info("Downlidng the dataset search index")
urllib.request.urlretrieve(
"http://phontron.com/data/prompt2model/dataset_index.json",
dataset_info_file,
self.dataset_info_file,
)
self.full_dataset_metadata = json.load(open(dataset_info_file, "r"))
self.full_dataset_metadata = json.load(open(self.dataset_info_file, "r"))
for dataset_name in sorted(self.full_dataset_metadata.keys()):
self.dataset_infos.append(
DatasetInfo(
Expand All @@ -96,47 +71,59 @@ def __init__(
score=0.0,
)
)
self.device = device

assert not os.path.isdir(
search_index_path
), f"Search index must either be a valid file or not exist yet. But {search_index_path} is provided." # noqa E501

def encode_dataset_descriptions(
self, dataset_infos, search_index_path
) -> np.ndarray:
"""Encode dataset descriptions into a vector for indexing."""
dataset_descriptions = [
dataset_info.description for dataset_info in dataset_infos
]
dataset_vectors = encode_text(
if os.path.isdir(self.search_index_path):
raise ValueError(
"Search index must either be a valid file or not exist yet. "
"But {self.search_index_path} is provided."
)
logger.info("Creating dataset descriptions")
encode_text(
self.encoder_model_name,
text_to_encode=dataset_descriptions,
encoding_file=search_index_path,
text_to_encode=[x.description for x in self.dataset_infos],
encoding_file=self.search_index_path,
device=self.device,
)
return dataset_vectors

# ---------------------------- Utility Functions ----------------------------
@staticmethod
def print_divider():
def _input_string():
"""Utility function to read a string from stdin."""
description = str(input())
return description

@staticmethod
def _input_y_n() -> bool:
"""Utility function to get a yes/no answer from the user via stdin."""
y_n = str(input())
return not (y_n.strip() == "" or y_n.strip().lower() == "n")

@staticmethod
def _print_divider():
"""Utility function to assist with the retriever's command line interface."""
print("\n-------------------------------------------------\n")

def choose_dataset(self, top_datasets: list[DatasetInfo]) -> str | None:
"""Have the user choose an appropriate dataset from a list of top datasets."""
self.print_divider()
def choose_dataset_by_cli(self, top_datasets: list[DatasetInfo]) -> str | None:
"""Have the user choose an appropriate dataset from a list of top datasets.

Args:
top_datasets: A list of top datasets to choose from.

Returns:
The name of the chosen dataset, or None if no dataset is chosen as relevant.
"""
self._print_divider()
print("Here are the datasets I've retrieved for you:")
print("#\tName\tDescription")
for i, d in enumerate(top_datasets):
description_no_spaces = d.description.replace("\n", " ")
print(f"{i+1}):\t{d.name}\t{description_no_spaces}")

self.print_divider()
self._print_divider()
print(
"If none of these are relevant to your prompt, we'll only use "
+ "generated data. Are any of these datasets relevant? (y/N)"
)
any_datasets_relevant = input_y_n()
any_datasets_relevant = self._input_y_n()
if any_datasets_relevant:
print(
"Which dataset would you like to use? Give the number between "
Expand All @@ -146,7 +133,7 @@ def choose_dataset(self, top_datasets: list[DatasetInfo]) -> str | None:
chosen_dataset_name = top_datasets[dataset_idx - 1].name
else:
chosen_dataset_name = None
self.print_divider() # noqa E501
self._print_divider()
return chosen_dataset_name

@staticmethod
Expand All @@ -159,11 +146,11 @@ def canonicalize_dataset_using_columns_for_split(
input_col = []
output_col = []
for i in range(len(dataset_split)):
input_string = ""
curr_string = ""
for col in input_columns:
input_string += f"{col}: {dataset_split[i][col]}\n"
input_string = input_string.strip()
input_col.append(input_string)
curr_string += f"{col}: {dataset_split[i][col]}\n"
curr_string = curr_string.strip()
input_col.append(curr_string)
output_col.append(dataset_split[i][output_column])
return datasets.Dataset.from_dict(
{"input_col": input_col, "output_col": output_col}
Expand All @@ -183,26 +170,33 @@ def canonicalize_dataset_using_columns(
)
return datasets.DatasetDict(dataset_dict)

def canonicalize_dataset(self, dataset_name: str) -> datasets.DatasetDict:
"""Canonicalize a dataset into a suitable text-to-text format."""
def canonicalize_dataset_by_cli(self, dataset_name: str) -> datasets.DatasetDict:
"""Canonicalize a dataset into a suitable text-to-text format.

Args:
dataset_name: The name of the dataset to canonicalize.

Returns:
A canonicalized dataset.
"""
configs = datasets.get_dataset_config_names(dataset_name)
chosen_config = None
if len(configs) == 1:
chosen_config = configs[0]
else:
self.print_divider()
self._print_divider()
print(f"Multiple dataset configs available: {configs}")
while chosen_config is None:
print("Which dataset config would you like to use for this?")
user_response = input_string()
user_response = self._input_string()
if user_response in configs:
chosen_config = user_response
else:
print(
f"Invalid config provided: {user_response}. Please choose "
+ "from {configs}\n\n"
)
self.print_divider()
self._print_divider()

dataset = datasets.load_dataset(dataset_name, chosen_config)
assert "train" in dataset
Expand All @@ -212,14 +206,14 @@ def canonicalize_dataset(self, dataset_name: str) -> datasets.DatasetDict:
assert len(dataset["train"]) > 0
example_rows = json.dumps(dataset["train"][0], indent=4)

self.print_divider()
self._print_divider()
print(f"Loaded dataset. Example row:\n{example_rows}\n")

print(
"Which column(s) should we use as input? Provide a comma-separated "
+ f"list from: {train_columns_formatted}."
)
user_response = input_string()
user_response = self._input_string()
input_columns = [c.strip() for c in user_response.split(",")]
print(f"Will use the columns {json.dumps(input_columns)} as input.\n")

Expand All @@ -229,7 +223,7 @@ def canonicalize_dataset(self, dataset_name: str) -> datasets.DatasetDict:
"Which column(s) should we use as the target? Choose a single "
+ f"value from: {train_columns_formatted}."
)
user_response = input_string()
user_response = self._input_string()
if user_response in train_columns:
output_column = user_response
else:
Expand All @@ -238,41 +232,37 @@ def canonicalize_dataset(self, dataset_name: str) -> datasets.DatasetDict:
+ f"from {train_columns}\n\n"
)
print(f'Will use the column "{output_column}" as our target.\n')
self.print_divider()
self._print_divider()

canonicalized_dataset = self.canonicalize_dataset_using_columns(
dataset, input_columns, output_column
)
return canonicalized_dataset

def retrieve_dataset_dict(
def retrieve_top_datasets(
self,
prompt_spec: PromptSpec,
) -> datasets.DatasetDict | None:
"""Select a dataset from a prompt using a dual-encoder retriever.
) -> list[DatasetInfo]:
"""Retrieve the top datasets for a prompt.

Specifically, the datasets are scored using a dual-encoder retriever model
and the datasets with the highest similarity scores with the query are returned.

Args:
prompt_spec: A prompt whose instruction field we use to retrieve datasets.

Return:
A list of relevant datasets dictionaries.
Returns:
A list of the top datasets for the prompt according to retriever score.
"""
if not os.path.exists(self.search_index_path):
print("Creating dataset descriptions")
self.encode_dataset_descriptions(self.dataset_infos, self.search_index_path)

query_text = prompt_spec.instruction

query_vector = encode_text(
self.encoder_model_name,
text_to_encode=query_text,
text_to_encode=prompt_spec.instruction,
device=self.device,
)
dataset_names = [dataset_info.name for dataset_info in self.dataset_infos]
ranked_list = retrieve_objects(
query_vector,
self.search_index_path,
dataset_names,
[x.name for x in self.dataset_infos],
self.first_stage_search_depth,
)
top_dataset_infos = []
Expand All @@ -287,8 +277,24 @@ def retrieve_dataset_dict(
sorted_list = sorted(top_dataset_infos, key=lambda x: x.score, reverse=True)[
: self.max_search_depth
]
assert len(sorted_list) > 0, "No datasets retrieved from search index."
top_dataset_name = self.choose_dataset(sorted_list)
if len(sorted_list) == 0:
raise ValueError("No datasets retrieved from search index.")
return sorted_list

def retrieve_dataset_dict(
self,
prompt_spec: PromptSpec,
) -> datasets.DatasetDict | None:
"""Select a dataset from a prompt using a dual-encoder retriever.

Args:
prompt_spec: A prompt whose instruction field we use to retrieve datasets.

Return:
A list of relevant datasets dictionaries.
"""
sorted_list = self.retrieve_top_datasets(prompt_spec)
top_dataset_name = self.choose_dataset_by_cli(sorted_list)
if top_dataset_name is None:
return None
return self.canonicalize_dataset(top_dataset_name)
return self.canonicalize_dataset_by_cli(top_dataset_name)
Loading