forked from ShishirPatil/gorilla
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
RAFT Support for chat and completion model formats (ShishirPatil#417)
Adds support to convert the dataset to formats expected to fine tune `completion` and `chat` models as specified there: https://platform.openai.com/docs/guides/fine-tuning/preparing-your-dataset `chat` format: ``` {"messages": [{"role": "system", "content": "Marv is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "What's the capital of France?"}, {"role": "assistant", "content": "Paris, as if everyone doesn't know that already."}]} {"messages": [{"role": "system", "content": "Marv is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "Who wrote 'Romeo and Juliet'?"}, {"role": "assistant", "content": "Oh, just some guy named William Shakespeare. Ever heard of him?"}]} {"messages": [{"role": "system", "content": "Marv is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "How far is the Moon from Earth?"}, {"role": "assistant", "content": "Around 384,400 kilometers. Give or take a few, like that really matters."}]} ``` `completion` format: ``` {"prompt": "<prompt text>", "completion": "<ideal generated text>"} {"prompt": "<prompt text>", "completion": "<ideal generated text>"} {"prompt": "<prompt text>", "completion": "<ideal generated text>"} ``` `raft.py`: - Supports `jsonl` and `parquet` output types - Supports `hf`, `chat` and `completion` formats - `chat` format also accepts a `--output-chat-system-prompt` param to configure the system prompt - Ignore venv folders - Added usage to --help ``` --output-format {hf,completion,chat} Format to convert the dataset to. Defaults to hf. --output-type {parquet,jsonl} Type to export the dataset to. Defaults to jsonl. 
--output-chat-system-prompt OUTPUT_CHAT_SYSTEM_PROMPT The system prompt to use when the output format is chat ``` New `format.py` script, to convert dataset previously generated by `raft.py`: ``` $ python format.py --help usage: format.py [-h] --input INPUT [--input-type {arrow,jsonl}] --output OUTPUT --output-format {hf,completion,chat} [--output-type {parquet,jsonl}] [--output-chat-system-prompt OUTPUT_CHAT_SYSTEM_PROMPT] options: -h, --help show this help message and exit --input INPUT Input HuggingFace dataset file --input-type {arrow,jsonl} Format of the input dataset. Defaults to arrow. --output OUTPUT Output file --output-format {hf,completion,chat} Format to convert the dataset to --output-type {parquet,jsonl} Type to export the dataset to. Defaults to jsonl. --output-chat-system-prompt OUTPUT_CHAT_SYSTEM_PROMPT The system prompt to use when the output format is chat ``` How to test `format.py` with the `chat` format: ``` python format.py --input output/data-00000-of-00001.arrow \ --output output/ucb-short.chat.jsonl \ --output-format chat \ --output-chat-system-prompt 'You are an AI expert on UC Berkeley' ``` How to test `format.py` with the `completion` format: ``` python format.py --input output/data-00000-of-00001.arrow \ --output output/ucb-short.completion.jsonl \ --output-format completion ``` How to test `raft.py` with the `chat` format: ``` python3 raft.py \ --datapath $PWD/sample_data/UC_Berkeley_short.pdf \ --output $PWD/output \ --distractors 3 \ --doctype pdf \ --chunk_size 512 \ --questions 2 \ --completion_model gpt-4-turbo \ --embedding_model text-embedding-ada-002 \ --output-format chat \ --output-chat-system-prompt "You're a RAG AI" ``` Co-authored-by: Shishir Patil <[email protected]>
- Loading branch information
1 parent
42f9d28
commit ae5f0a2
Showing
4 changed files
with
238 additions
and
17 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,173 @@ | ||
from abc import ABC, abstractmethod | ||
import argparse | ||
from datasets import Dataset, load_dataset | ||
from typing import Dict, Literal, Any, get_args | ||
|
||
""" | ||
This file allows to convert raw HuggingFace Datasets into files suitable to fine tune completion and chat models. | ||
""" | ||
|
||
OutputDatasetType = Literal["parquet", "jsonl"] | ||
outputDatasetTypes = list(get_args(OutputDatasetType)) | ||
|
||
InputDatasetType = Literal["arrow", "jsonl"] | ||
inputDatasetTypes = list(get_args(InputDatasetType)) | ||
|
||
DatasetFormat = Literal["hf", "completion", "chat"] | ||
datasetFormats = list(get_args(DatasetFormat)) | ||
|
||
def get_args() -> argparse.Namespace:
    """
    Parse and return the command-line arguments for this script.

    NOTE(review): this function shadows ``typing.get_args`` imported at the
    top of the file; the module-level ``list(get_args(...))`` calls run before
    this ``def`` executes, so behavior is unaffected, but a rename would be
    safer.
    """
    arg_parser = argparse.ArgumentParser()

    arg_parser.add_argument("--input", type=str, required=True, help="Input HuggingFace dataset file")
    arg_parser.add_argument("--input-type", type=str, default="arrow", help="Format of the input dataset. Defaults to arrow.", choices=inputDatasetTypes)
    arg_parser.add_argument("--output", type=str, required=True, help="Output file")
    arg_parser.add_argument("--output-format", type=str, required=True, help="Format to convert the dataset to", choices=datasetFormats)
    arg_parser.add_argument("--output-type", type=str, default="jsonl", help="Type to export the dataset to. Defaults to jsonl.", choices=outputDatasetTypes)
    arg_parser.add_argument("--output-chat-system-prompt", type=str, help="The system prompt to use when the output format is chat")

    return arg_parser.parse_args()
|
||
class DatasetFormatter(ABC):
    """
    Abstract base for dataset formatters.

    A formatter reshapes a HuggingFace Dataset — renaming, removing and adding
    columns — so it matches one of the target structures: HF, Chat or
    Completion model file formats.
    https://platform.openai.com/docs/guides/fine-tuning/preparing-your-dataset
    """

    @abstractmethod
    def format(self, ds: Dataset, params: Dict[str, str]) -> Dataset:
        ...
|
||
class DatasetExporter(ABC):
    """
    Abstract base for dataset exporters.

    An exporter writes a Dataset to a concrete file type: JSONL, Parquet, ...
    """

    @abstractmethod
    def export(self, ds: Dataset, output_path: str):
        ...
|
||
class DatasetConverter():
    """
    Entry point class. Resolves which DatasetFormatter and which
    DatasetExporter to use for the requested format/type and runs them.
    """
    # Maps each supported output format to its formatter instance.
    formats: Dict[DatasetFormat, DatasetFormatter]
    # Maps each supported output file type to its exporter instance.
    exporters: Dict[OutputDatasetType, Any]

    def __init__(self) -> None:
        self.formats = {
            "hf": HuggingFaceDatasetFormatter(),
            "completion": OpenAiCompletionDatasetFormatter(),
            "chat": OpenAiChatDatasetFormatter()
        }
        self.exporters = {
            "parquet": ParquetDatasetExporter(),
            "jsonl": JsonlDatasetExporter()
        }

    def convert(self, ds: Dataset, format: DatasetFormat, output_path: str, output_type: OutputDatasetType, params: Dict[str, str]):
        """
        Format `ds` according to `format` and export it to `output_path` as
        `output_type`. `params` carries formatter options (e.g. the chat
        system prompt under the 'system_prompt' key).

        Raises:
            ValueError: if `format` or `output_type` is not supported.
        """
        # Fixes: "pleased" -> "please" typo in both messages; render the
        # supported keys as a plain list instead of a dict_keys repr;
        # ValueError (a subclass of Exception, so callers catching Exception
        # still work) instead of the generic Exception.
        if format not in self.formats:
            raise ValueError(f"Output Format {format} is not supported, please select one of {list(self.formats.keys())}")

        if output_type not in self.exporters:
            raise ValueError(f"Output Type {output_type} is not supported, please select one of {list(self.exporters.keys())}")

        formatter = self.formats[format]
        newds = formatter.format(ds, params)
        exporter = self.exporters[output_type]
        exporter.export(newds, output_path)
|
||
class HuggingFaceDatasetFormatter(DatasetFormatter):
    """
    Identity formatter: the HuggingFace Dataset is already in the target
    structure, so it is returned untouched.
    """

    def format(self, ds: Dataset, params: Dict[str, str]) -> Dataset:
        # Nothing to reshape for the native HF format.
        return ds
|
||
def _remove_all_columns_but(ds: Dataset, keep_columns) -> Dataset: | ||
""" | ||
HF Dataset doesn't have a way to copy only specific columns of a Dataset so this help | ||
removes all columns but the ones specified. | ||
""" | ||
remove_columns = list(ds.column_names) | ||
for keep in keep_columns: | ||
remove_columns.remove(keep) | ||
ds = ds.remove_columns(remove_columns) | ||
return ds | ||
|
||
class OpenAiCompletionDatasetFormatter(DatasetFormatter):
    """
    Formats the Dataset for the OpenAI Completion fine-tuning file format,
    keeping exactly two fields: "prompt" and "completion".
    https://platform.openai.com/docs/guides/fine-tuning/preparing-your-dataset
    """

    def format(self, ds: Dataset, params: Dict[str, str]) -> Dataset:
        # Map the RAFT column names onto the OpenAI completion field names.
        renamed = ds.rename_columns({'question': 'prompt', 'cot_answer': 'completion'})
        return _remove_all_columns_but(renamed, ['prompt', 'completion'])
|
||
class OpenAiChatDatasetFormatter(OpenAiCompletionDatasetFormatter):
    """
    Formats the Dataset for the OpenAI Chat fine-tuning file format, producing
    a single "messages" field per row.
    https://platform.openai.com/docs/guides/fine-tuning/preparing-your-dataset
    """

    def format(self, ds: Dataset, params: Dict[str, str]) -> Dataset:
        # Reuse the completion formatter to obtain 'prompt'/'completion'
        # columns, then fold each row into a chat message list.
        base = super().format(ds, params)

        def to_chat(row):
            messages = []
            # Optional system prompt supplied via --output-chat-system-prompt.
            if 'system_prompt' in params:
                messages.append({"role": "system", "content": params['system_prompt']})
            messages.append({"role": "user", "content": row['prompt']})
            messages.append({"role": "assistant", "content": row['completion']})
            return {"messages": messages}

        return _remove_all_columns_but(base.map(to_chat), ['messages'])
|
||
def append_extension(path: str, extension: str) -> str:
    """
    Return `path` guaranteed to end with `"." + extension`; the suffix is
    appended only when it is not already present.
    """
    suffix = f".{extension}"
    return path if path.endswith(suffix) else path + suffix
|
||
|
||
class JsonlDatasetExporter(DatasetExporter):
    """
    Writes the Dataset out as a JSONL (JSON Lines) file.
    """

    def export(self, ds: Dataset, output_path: str):
        # Ensure the output file carries a .jsonl extension before writing.
        target = append_extension(output_path, "jsonl")
        ds.to_json(target)
|
||
|
||
class ParquetDatasetExporter(DatasetExporter):
    """
    Writes the Dataset out as a Parquet file.
    """

    def export(self, ds: Dataset, output_path: str):
        # Ensure the output file carries a .parquet extension before writing.
        target = append_extension(output_path, "parquet")
        ds.to_parquet(target)
|
||
|
||
def main():
    """
    Command-line entry point for format.py: loads the dataset previously
    generated by raft.py, converts it to the requested format and exports it.
    (The original docstring incorrectly referred to raft.py as the script
    being executed.)
    """
    args = get_args()

    # Validate the argument combination BEFORE the expensive load_dataset
    # call, so the user gets fast feedback. ValueError (still an Exception
    # subclass) gives callers a more precise type to catch.
    if args.output_chat_system_prompt and args.output_format != "chat":
        raise ValueError("Parameter --output-chat-system-prompt can only be used with --output-format chat")

    # load_dataset uses the input type ("arrow" / "jsonl") as the builder name.
    ds = load_dataset(args.input_type, data_files={"train": args.input})['train']

    format_params = {}
    if args.output_chat_system_prompt:
        format_params['system_prompt'] = args.output_chat_system_prompt

    # Renamed from "formatter": this object is a DatasetConverter, which
    # resolves both the formatter and the exporter.
    converter = DatasetConverter()
    converter.convert(ds=ds, format=args.output_format, output_path=args.output, output_type=args.output_type, params=format_params)


if __name__ == "__main__":
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters