RAFT Recovery Mode for interruptions #410

Merged · 7 commits · May 4, 2024
5 changes: 3 additions & 2 deletions raft/README.md
@@ -25,13 +25,14 @@ Arguments:
- `--openai_key` - Your OpenAI key used to make queries to GPT-3.5 or GPT-4
- `--embedding-model` - The embedding model to use to encode document chunks. Defaults to `text-embedding-ada-002`.
- `--completion-model` - The model to use to generate questions and answers. Defaults to `gpt-4`.
- `--fast` - Fast mode flag. By default this flag is omitted and the script runs in safe mode, saving checkpoint datasets so that an interrupted run can recover and continue where it left off (see the example below). Include this flag to run RAFT without recovery.
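For example, an interrupted run can be resumed by simply re-running the original command; the saved checkpoint is picked up automatically (an abbreviated sketch; reuse the exact arguments from the interrupted run):

```bash
# First run is interrupted partway through (Ctrl-C, dropped connection, ...)
python3 raft.py --datapath PATH_TO_DATA --output OUTPUT_PATH --doctype pdf --openai_key YOUR_OPENAI_KEY
# Re-running the identical command continues from the last checkpoint
python3 raft.py --datapath PATH_TO_DATA --output OUTPUT_PATH --doctype pdf --openai_key YOUR_OPENAI_KEY
```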


## Usage with OpenAI API

Run the following command with your desired arguments to generate the dataset.
```bash
python3 raft.py --datapath PATH_TO_DATA --output OUTPUT_PATH --distractors 3 --doctype pdf --chunk_size 512 --questions 5 --openai_key YOUR_OPENAI_KEY
python3 raft.py --datapath PATH_TO_DATA --output OUTPUT_PATH --distractors 3 --p 1.0 --doctype pdf --chunk_size 512 --questions 5 --openai_key YOUR_OPENAI_KEY
```
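To skip checkpointing entirely, add the new `--fast` flag to the same command:

```bash
python3 raft.py --datapath PATH_TO_DATA --output OUTPUT_PATH --distractors 3 --p 1.0 --doctype pdf --chunk_size 512 --questions 5 --openai_key YOUR_OPENAI_KEY --fast
```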

**Note**: As an alternative to passing the OpenAI key with the `--openai_key` argument, you can also store the standard OpenAI environment variables in a file called `.env` like so. All standard OpenAI environment variables are supported.
@@ -49,7 +50,7 @@ OPENAI_API_KEY=<replace_me>

## Usage with Azure OpenAI API

Create a file `.env` like so. All standard Azure OpenAI environement variables are supported.
Create a file `.env` like so. All standard Azure OpenAI environment variables are supported.

```
# Azure OpenAI API
```
69 changes: 61 additions & 8 deletions raft/raft.py
@@ -5,11 +5,13 @@
from typing import Literal, Any
import argparse
from openai import OpenAI
import datasets
from datasets import Dataset, load_dataset
from transformers import AutoTokenizer
import json
import PyPDF2
import random
import os, shutil
from langchain_experimental.text_splitter import SemanticChunker
from langchain_openai.embeddings import OpenAIEmbeddings
from client_utils import build_openai_client, build_langchain_embeddings
@@ -21,6 +23,9 @@

DocType = Literal["api", "pdf", "json", "txt"]

# Every N chunks, save checkpoint
N = 15
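# For illustration: with N = 15, dataset shards are written after chunks 14,
# 29, 44, ..., and any remainder is flushed to "<output>-checkpoints-last".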

def get_args() -> argparse.Namespace:
"""
Parses and returns the arguments specified by the user's command
@@ -35,8 +40,9 @@ def get_args() -> argparse.Namespace:
parser.add_argument("--chunk_size", type=int, default=512, help="The size of each chunk in number of tokens")
parser.add_argument("--doctype", type=str, default="pdf", help="The type of the document, must be one of the accepted doctypes", choices=["pdf", "txt", "json", "api"])
parser.add_argument("--openai_key", type=str, default=None, help="Your OpenAI key used to make queries to GPT-3.5 or GPT-4")
parser.add_argument("--embedding-model", type=str, default="text-embedding-ada-002", help="The embedding model to use to encode documents chunks (text-embedding-ada-002, ...)")
parser.add_argument("--completion-model", type=str, default="gpt-4", help="The model to use to generate questions and answers (gpt-3.5, gpt-4, ...)")
parser.add_argument("--embedding_model", type=str, default="text-embedding-ada-002", help="The embedding model to use to encode documents chunks (text-embedding-ada-002, ...)")
parser.add_argument("--completion_model", type=str, default="gpt-4", help="The model to use to generate questions and answers (gpt-3.5, gpt-4, ...)")
parser.add_argument("--fast", action="store_true", help="Run the script in fast mode (no recovery implemented)")

args = parser.parse_args()
return args
@@ -273,6 +279,14 @@ def add_chunk_to_dataset(
        else:
            ds = ds.add_item(datapt)

def save_checkpoint(state, filename):
    """Save the index of the chunk currently being processed."""
    with open(filename, 'w') as f:
        f.write(str(state))

def load_checkpoint(filename):
    """Load the chunk index stored by save_checkpoint."""
    with open(filename, 'r') as f:
        return eval(f.read())
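
# Note: str()/eval() round-trips the integer index saved above. A stricter
# variant (hypothetical, not part of this change) avoids eval-ing file contents:
def load_checkpoint_strict(filename):
    with open(filename, 'r') as f:
        # int() raises ValueError on a corrupt checkpoint instead of
        # executing whatever the file happens to contain
        return int(f.read().strip())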

def main():
    global ds

@@ -293,18 +307,57 @@ def main():
    ds = None

    num_chunks = len(chunks)
    for i, chunk in enumerate(chunks):
        perc = ceil(i / num_chunks * 100)
        with MDC(progress=f"{perc}%"):
            logger.info(f"Adding chunk {i}/{num_chunks}")
            add_chunk_to_dataset(client, chunks, chunk, args.doctype, args.questions, NUM_DISTRACT_DOCS, model=args.completion_model)

    if not args.fast:
        start = 0
        if os.path.exists("checkpoint.txt"):
            start = int(load_checkpoint("checkpoint.txt"))

        # Resume from the last completed batch boundary; chunks from a
        # partially saved batch of N are regenerated.
        for i in range((start//N)*N, len(chunks)):
            chunk = chunks[i]
            save_checkpoint(i, "checkpoint.txt")

            perc = ceil(i / num_chunks * 100)
            with MDC(progress=f"{perc}%"):
                logger.info(f"Adding chunk {i}/{num_chunks}")
                add_chunk_to_dataset(client, chunks, chunk, args.doctype, args.questions, NUM_DISTRACT_DOCS, model=args.completion_model)

            if (i+1) % N == 0:
                ds.save_to_disk(args.output + "-checkpoints-" + str(i))
                ds = None

        if ds:
            ds.save_to_disk(args.output + "-checkpoints-last")

        # Collect every saved shard and merge them into a single dataset.
        ds_list = []

        for filename in os.listdir(os.path.dirname(args.output)):
            if "-checkpoints-" in filename:
                for f in os.listdir(os.path.dirname(args.output) + "/" + filename):
                    if f.endswith(".arrow"):
                        ds_list.append(Dataset.from_file(os.path.dirname(args.output) + "/" + filename + "/" + f))

        ds = datasets.concatenate_datasets(ds_list)
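        # Hypothetical shard layout with args.output = "./out" and N = 15:
        #   ./out-checkpoints-14/  ./out-checkpoints-29/  ./out-checkpoints-last/
        # os.listdir imposes no ordering, so merged rows may not follow the
        # original chunk order.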
    else:
        for i, chunk in enumerate(chunks):
            perc = ceil(i / num_chunks * 100)
            with MDC(progress=f"{perc}%"):
                logger.info(f"Adding chunk {i}/{num_chunks}")
                add_chunk_to_dataset(client, chunks, chunk, args.doctype, args.questions, NUM_DISTRACT_DOCS, model=args.completion_model)

    # Save as .arrow format
    ds.save_to_disk(args.output)

    # Save as .jsonl format
    ds.to_json(args.output + ".jsonl")

    if not args.fast:
        # The final dataset is safely written, so drop the recovery artifacts.
        os.remove("checkpoint.txt")
        for filename in os.listdir(os.path.dirname(args.output)):
            if "-checkpoints-" in filename:
                shutil.rmtree(os.path.dirname(args.output) + "/" + filename)

if __name__ == "__main__":
    with MDC(progress="0%"):
        main()
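Once a run completes, the merged dataset can be loaded back for inspection with the standard `datasets` helpers (a sketch; `OUTPUT_PATH` is the value passed to `--output`):

```python
from datasets import load_dataset, load_from_disk

ds = load_from_disk("OUTPUT_PATH")                            # the .arrow directory
jsonl = load_dataset("json", data_files="OUTPUT_PATH.jsonl")  # the .jsonl export
print(ds)
```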
Binary file added raft/sample_data/UC_Berkeley_short.pdf