RAFT Enhancements: Improved robustness, logging, checkpointing, threading, Llama support, Azure auth and eval #604

Merged · 81 commits · Aug 27, 2024

Commits
44e2ca7
Resolve default logging conf file relative to logconf.py
cedricvidal May 10, 2024
bc24be6
Format support for eval format
cedricvidal May 13, 2024
e12174c
RAFT Fix wrong chunk being checkpointed
cedricvidal May 14, 2024
3aec2a9
RAFT format.py Fix jsonl input type
cedricvidal May 14, 2024
b7ad22b
RAFT format.py more logs
cedricvidal May 14, 2024
20b3349
RAFT format.py Field names can be customized
cedricvidal May 14, 2024
6a4239e
RAFT format.py support for eval format
cedricvidal May 14, 2024
79ef352
RAFT format.py renamed answer to gold_answer
cedricvidal May 14, 2024
7cb803f
RAFT eval notebook works end to end
cedricvidal May 15, 2024
6222cb5
RAFT raft.py add --checkpoint-size arg
cedricvidal May 15, 2024
98d01e3
RAFT raft.py fix chunk and progress logging
cedricvidal May 15, 2024
26dbaa4
RAFT raft.py saving checkpoints in a single directory
cedricvidal May 15, 2024
de2bc2e
RAFT raft.py disable datasets progress bars when saving checkpoints
cedricvidal May 15, 2024
56f469b
RAFT raft.py add llama template
cedricvidal May 15, 2024
ede07e0
RAFT raft.py llama prompt test 1
cedricvidal May 15, 2024
ade26e9
RAFT raft.py fixed missed change to formatter
cedricvidal May 15, 2024
b7422c7
RAFT raft.py specific template for llama to generate questions
cedricvidal May 15, 2024
2c5ac21
RAFT format.py stop keyword
cedricvidal May 16, 2024
b509584
RAFT eval.py stop keyword
cedricvidal May 16, 2024
3a125f5
RAFT raft.py gpt qa template formatting
cedricvidal May 16, 2024
b48e390
RAFT format.py final_answer
cedricvidal May 16, 2024
c550d7a
RAFT raft.py save chunks to checkpoints folder
cedricvidal May 17, 2024
2de88b9
RAFT raft.py moved chunks checkpoint support to method
cedricvidal May 17, 2024
249a375
RAFT raft.py rename chunks func
cedricvidal May 18, 2024
a587d25
RAFT format.py more logging
cedricvidal May 18, 2024
a29ac6f
RAFT raft.py Moved checkpointing logic
cedricvidal May 18, 2024
c32a199
RAFT raft.py more checkpoint refactoring
cedricvidal May 18, 2024
3a09923
RAFT raft.py relying on checkpoint directories instead of state file
cedricvidal May 18, 2024
8b65cb6
RAFT raft.py removed global ds
cedricvidal May 18, 2024
3f214ee
RAFT raft.py multi threading
cedricvidal May 18, 2024
9d5de71
RAFT raft.py --workers param
cedricvidal May 18, 2024
fc38956
RAFT raft.py ready to use
cedricvidal May 18, 2024
874fe06
RAFT raft.py tqdm for chunking
cedricvidal May 19, 2024
369d5c8
RAFT raft.py removed --fast param to simplify code
cedricvidal May 19, 2024
9f3c15f
RAFT raft.py formatting
cedricvidal May 19, 2024
9abc8d4
RAFT raft.py caching questions with HF's map
cedricvidal May 19, 2024
d2ed535
RAFT raft.py cot_answer multi threading
cedricvidal May 19, 2024
20c3636
RAFT raft.py renaming
cedricvidal May 19, 2024
82bddfa
RAFT raft.py cot answers now checkpointed
cedricvidal May 19, 2024
d2b699f
RAFT raft.py tuned completion max tokens
cedricvidal May 19, 2024
18370d2
RAFT raft.py checkpointing questions and answers in the same chunk loop
cedricvidal May 19, 2024
0aeeaa3
RAFT ignore output dir
cedricvidal May 19, 2024
f144bbc
RAFT ignore datasets dir
cedricvidal May 19, 2024
357002c
RAFT raft.py moved checkpointing stuff and removed hf fingerprint code
cedricvidal May 19, 2024
e1cec66
RAFT raft.py ChatCompleter stats
cedricvidal May 19, 2024
a769582
RAFT raft.py tok/s
cedricvidal May 19, 2024
aec786e
RAFT raft.py renamed UsageStats
cedricvidal May 20, 2024
5289f10
RAFT raft.py setting initial in tqdm to avoid skewing the stats
cedricvidal May 20, 2024
5e19c15
RAFT update openai version
cedricvidal May 20, 2024
04950b1
RAFT raft.py avg tok/s
cedricvidal May 20, 2024
d1cf608
RAFT raft.py --auto-clean-checkpoints
cedricvidal May 20, 2024
68515fb
RAFT format.py filter out empty rows and added descriptions to HF map…
cedricvidal May 20, 2024
6827862
RAFT format.py EvalDatasetFormatter
cedricvidal May 21, 2024
fba6cb2
RAFT raft.py Final log
cedricvidal May 21, 2024
7b753dd
RAFT raft.py more logging at the end
cedricvidal May 21, 2024
c13660a
RAFT eval.py retry and --workers param
cedricvidal May 22, 2024
44c11ab
RAFT client_utils.py stats support for completion
cedricvidal May 23, 2024
965c2e6
RAFT eval.py more robust
cedricvidal May 23, 2024
dced542
RAFT raft.py tqdm postfix
cedricvidal May 23, 2024
f6039ea
RAFT eval.py logging retry stats
cedricvidal May 23, 2024
d767416
RAFT eval.py fixed main thread silent fail in case of exception
cedricvidal May 23, 2024
fe2dfbe
RAFT format.py 'answer' column is optional
cedricvidal May 23, 2024
47c23c5
RAFT Display PDFs in notebooks
cedricvidal May 29, 2024
4c9f963
RAFT Support for chat and completion models
cedricvidal Jun 18, 2024
70723f0
Ignore config.json
cedricvidal Jul 18, 2024
219a256
Updated README with new CLI parameters
cedricvidal Jul 19, 2024
87003af
Display default values in help + some help cleanup
cedricvidal Jul 19, 2024
af53b91
Display default values in format.py help
cedricvidal Jul 19, 2024
7535756
Moving notebooks to separate repo
cedricvidal Jul 19, 2024
57dc942
Fixed response Choice format
cedricvidal Jul 30, 2024
212476b
Logging resolved prefixed OPENAI env vars
cedricvidal Jul 31, 2024
52274f7
eval.py now takes the env prefix as param
cedricvidal Jul 31, 2024
fc8a645
Skipping chunks that raise content safety alerts
cedricvidal Aug 6, 2024
45a825e
More content filtering support in the question generation step
cedricvidal Aug 7, 2024
e58ca85
Fixed chat format and added default chat system prompt
cedricvidal Aug 7, 2024
8acf9c0
Fixed bug when format is completion
cedricvidal Aug 7, 2024
4a5c5cc
Support for Azure OpenAI Keyless and Managed Identity authentication
cedricvidal Aug 23, 2024
d99bbdb
Early stopping after QA threshold is met
cedricvidal Aug 25, 2024
e95d6f4
Trimming the generated dataset to qa_threshold if set
cedricvidal Aug 25, 2024
f1b84f6
Using questions to track progress if qa_threshold is set
cedricvidal Aug 25, 2024
b2b41f8
Merge branch 'main' into upstream-merge-prep
ShishirPatil Aug 27, 2024
1 change: 1 addition & 0 deletions raft/.gitignore
@@ -1 +1,2 @@
.venv/
output/
32 changes: 29 additions & 3 deletions raft/README.md
@@ -21,11 +21,13 @@ pip install -r requirements.txt
```

Arguments:
- `--datapath` - the path at which the document is located
- `--datapath` - if a file, the path at which the document is located; if a folder, the folder from which all documents are loaded
- `--output` - the path at which to save the dataset
- `--output-format` - the format of the output dataset. Defaults to `hf` for HuggingFace. Can be one of `hf`, `completion`, `chat`.
- `--output-format` - the format of the output dataset. Defaults to `hf` for HuggingFace. Can be one of `hf`, `completion`, `chat`, `eval`.
- `--output-type` - the type of the output dataset file. Defaults to `jsonl`. Can be one of `jsonl`, `parquet`.
- `--output-chat-system-prompt` - The system prompt to use when the output format is `chat`. Optional.
- `--output-completion-prompt-column` - The column (json field name) for the `prompt` / `instruction` when using the `completion` output format. Defaults to `prompt`.
- `--output-completion-completion-column` - The column (json field name) for the `completion` when using the `completion` output format. Defaults to `completion`.
- `--distractors` - the number of distractor documents to include per data point / triplet
- `--doctype` - the type of the document, must be one of the accepted doctypes
- currently accepted doctypes: `pdf`, `txt`, `json`, `api`
@@ -37,8 +39,11 @@ Arguments:
- `--openai_key` - your OpenAI key used to make queries to GPT-3.5 or GPT-4
- `--embedding-model` - The embedding model to use to encode documents chunks. Defaults to `text-embedding-ada-002`.
- `--completion-model` - The model to use to generate questions and answers. Defaults to `gpt-4`.
- `--fast` - Fast mode flag. By default, this flag is not included and the script runs in safe mode, where it saves checkpoint datasets, allowing the script to recover and continue where it left off in the case of an interruption. Include this flag to run RAFT without recovery.
- `--system-prompt-key` - The system prompt key to use to generate the dataset. Defaults to `gpt`. Can be one of `gpt`, `llama`.
- `--workers` - The number of worker threads to use to generate the dataset. Defaults to 2.
- `--auto-clean-checkpoints` - Whether to auto clean the checkpoints after the dataset is generated. Defaults to `false`.

*Note*: The `--fast` mode flag has been removed; checkpointing is now always active.
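
For illustration only, here is one way the generation script might be invoked with the newer options. The input folder, output path, prompt text, and model name below are placeholders rather than defaults, and credentials are assumed to come from environment variables or a `.env` file:

```
python3 raft.py \
  --datapath ./sample_docs \
  --doctype pdf \
  --output ./output \
  --output-format chat \
  --output-chat-system-prompt "You are a helpful domain assistant." \
  --completion-model gpt-4 \
  --embedding-model text-embedding-ada-002 \
  --system-prompt-key llama \
  --workers 4
```

Because checkpointing is always active, re-running the same command after an interruption resumes from the last saved checkpoint.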

## Usage

@@ -219,6 +224,27 @@ python3 format.py --input output/data-00000-of-00001.arrow --output output.compl

```
python3 format.py --help

usage: format.py [-h] --input INPUT [--input-type {arrow,jsonl}] --output OUTPUT --output-format {hf,completion,chat,eval} [--output-type {parquet,jsonl}] [--output-chat-system-prompt OUTPUT_CHAT_SYSTEM_PROMPT] [--output-completion-prompt-column OUTPUT_COMPLETION_PROMPT_COLUMN] [--output-completion-completion-column OUTPUT_COMPLETION_COMPLETION_COLUMN] [--output-completion-stop OUTPUT_COMPLETION_STOP]

options:
-h, --help show this help message and exit
--input INPUT Input HuggingFace dataset file (default: None)
--input-type {arrow,jsonl}
Format of the input dataset. Defaults to arrow. (default: arrow)
--output OUTPUT Output file (default: None)
--output-format {hf,completion,chat,eval}
Format to convert the dataset to (default: None)
--output-type {parquet,jsonl}
Type to export the dataset to. Defaults to jsonl. (default: jsonl)
--output-chat-system-prompt OUTPUT_CHAT_SYSTEM_PROMPT
The system prompt to use when the output format is chat (default: None)
--output-completion-prompt-column OUTPUT_COMPLETION_PROMPT_COLUMN
The prompt column name to use for the completion format (default: prompt)
--output-completion-completion-column OUTPUT_COMPLETION_COMPLETION_COLUMN
The completion column name to use for the completion format (default: completion)
--output-completion-stop OUTPUT_COMPLETION_STOP
The stop keyword to use for the completion format (default: <STOP>)
```

**Note**: If fine-tuning a chat model, you need to use `--output-format chat` and can optionally add the `--output-chat-system-prompt` parameter to configure the system prompt included in the dataset.
77 changes: 77 additions & 0 deletions raft/checkpointing.py
@@ -0,0 +1,77 @@
from dataclasses import dataclass
from pathlib import Path
from typing import List
from datasets import Dataset, concatenate_datasets
import logging
import shutil

logger = logging.getLogger("raft")

@dataclass
class Checkpoint:
path: Path
num: int

def load(self) -> Dataset:
return Dataset.load_from_disk(self.path)

def __lt__(self, other: 'Checkpoint') -> bool:
return self.num < other.num

def __eq__(self, other: 'Checkpoint') -> bool:
return self.num == other.num

def __hash__(self) -> int:
return hash(self.num)

class Checkpointing:

def __init__(self, checkpoints_dir: Path) -> None:
self.checkpoints_dir = checkpoints_dir

def missing_checkpoints(self, num) -> List[int]:
return [n for n in range(0, num) if not (self.checkpoints_dir / f"checkpoint-{n}").exists()]

def save_checkpoint(self, ds: Dataset, num: int):
checkpoint_path = self.checkpoints_dir / ("checkpoint-" + str(num))
ds.save_to_disk(checkpoint_path)

def load_checkpoint(self, num: int):
checkpoint_path = self.checkpoints_dir / ("checkpoint-" + str(num))
if checkpoint_path.exists():
return Dataset.load_from_disk(checkpoint_path)
return None

def get_checkpoints(self) -> List[Checkpoint]:
checkpoints = []
if not self.checkpoints_dir.exists():
return checkpoints
for dir_path in self.checkpoints_dir.iterdir():
if dir_path.is_dir() and dir_path.name.startswith("checkpoint-"):
num = int(dir_path.name.split("-")[1])
checkpoints.append(Checkpoint(dir_path, num))
return checkpoints

def has_checkpoints(self) -> bool:
return len(self.get_checkpoints()) > 0

def collect_checkpoints(self) -> Dataset:
ds_list = list([checkpoint.load() for checkpoint in self.get_checkpoints()])
ds = concatenate_datasets(ds_list)
return ds

def delete_checkpoints(self):
shutil.rmtree(self.checkpoints_dir)

def checkpointed(checkpointing: Checkpointing):
def wrapped(func):
def wrapper(chunk_id, *args, **kwargs):
ds = checkpointing.load_checkpoint(chunk_id)
if ds:
return ds
ds = func(chunk_id=chunk_id, *args, **kwargs)
if ds.num_rows > 0:
checkpointing.save_checkpoint(ds, chunk_id)
return ds
return wrapper
return wrapped
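
To make the recovery flow concrete, here is a minimal usage sketch of `Checkpointing` and the `checkpointed` decorator from the file above. The checkpoint directory, the `build_qa_for_chunk` helper, and the dataset contents are hypothetical; in `raft.py` the wrapped function would call the completion model.

```python
from pathlib import Path
from datasets import Dataset
from checkpointing import Checkpointing, checkpointed

checkpointing = Checkpointing(Path("output/checkpoints"))  # placeholder location

@checkpointed(checkpointing)
def build_qa_for_chunk(chunk_id: int) -> Dataset:
    # Hypothetical worker: returns one row per chunk for illustration.
    return Dataset.from_dict({"chunk_id": [chunk_id], "question": [f"Q for chunk {chunk_id}"]})

# Only chunks whose checkpoint directory is missing are processed;
# already-checkpointed chunks are loaded from disk by the decorator.
for chunk_id in checkpointing.missing_checkpoints(4):
    build_qa_for_chunk(chunk_id=chunk_id)

# Concatenate every checkpoint-N dataset into the final dataset.
ds = checkpointing.collect_checkpoints()
print(ds.num_rows)
```

Re-running the loop after an interruption skips completed chunks, which is what replaces the old `--fast` / state-file mechanism.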
128 changes: 119 additions & 9 deletions raft/client_utils.py
@@ -1,24 +1,29 @@
from abc import ABC
from typing import Any
from dotenv import load_dotenv
from langchain_openai import OpenAIEmbeddings, AzureOpenAIEmbeddings
from openai import AzureOpenAI, OpenAI
import logging
from env_config import read_env_config, set_env
from os import environ
from os import environ, getenv
import time
from threading import Lock
from azure.identity import DefaultAzureCredential, ManagedIdentityCredential
from azure.identity import get_bearer_token_provider

logger = logging.getLogger("client_utils")

load_dotenv() # take environment variables from .env.
logger = logging.getLogger("client_utils")

def build_openai_client(**kwargs: Any) -> OpenAI:
def build_openai_client(env_prefix : str = "COMPLETION", **kwargs: Any) -> OpenAI:
"""
Build OpenAI client based on the environment variables.
"""

env = read_env_config("COMPLETION")
kwargs = _remove_empty_values(kwargs)
env = read_env_config(env_prefix)
with set_env(**env):
if is_azure():
client = AzureOpenAI(**kwargs)
auth_args = _get_azure_auth_client_args()
client = AzureOpenAI(**auth_args, **kwargs)
else:
client = OpenAI(**kwargs)
return client
@@ -28,19 +33,124 @@ def build_langchain_embeddings(**kwargs: Any) -> OpenAIEmbeddings:
Build OpenAI embeddings client based on the environment variables.
"""

kwargs = _remove_empty_values(kwargs)
env = read_env_config("EMBEDDING")

with set_env(**env):
if is_azure():
client = AzureOpenAIEmbeddings(**kwargs)
auth_args = _get_azure_auth_client_args()
client = AzureOpenAIEmbeddings(**auth_args, **kwargs)
else:
client = OpenAIEmbeddings(**kwargs)
return client

def _remove_empty_values(d: dict) -> dict:
return {k: v for k, v in d.items() if v is not None}

def _get_azure_auth_client_args() -> dict:
"""Handle Azure OpenAI Keyless, Managed Identity and Key based authentication
https://techcommunity.microsoft.com/t5/microsoft-developer-community/using-keyless-authentication-with-azure-openai/ba-p/4111521
"""
client_args = {}
if getenv("AZURE_OPENAI_KEY"):
logger.info("Using Azure OpenAI Key based authentication")
client_args["api_key"] = getenv("AZURE_OPENAI_KEY")
else:
if client_id := getenv("AZURE_OPENAI_CLIENT_ID"):
# Authenticate using a user-assigned managed identity on Azure
logger.info("Using Azure OpenAI Managed Identity Keyless authentication")
azure_credential = ManagedIdentityCredential(client_id=client_id)
else:
# Authenticate using the default Azure credential chain
logger.info("Using Azure OpenAI Default Azure Credential Keyless authentication")
azure_credential = DefaultAzureCredential()

client_args["azure_ad_token_provider"] = get_bearer_token_provider(
azure_credential, "https://cognitiveservices.azure.com/.default")
client_args["api_version"] = getenv("AZURE_OPENAI_API_VERSION") or "2024-02-15-preview"
client_args["azure_endpoint"] = getenv("AZURE_OPENAI_ENDPOINT")
client_args["azure_deployment"] = getenv("AZURE_OPENAI_DEPLOYMENT")
return client_args

def is_azure():
azure = "AZURE_OPENAI_ENDPOINT" in environ or "AZURE_OPENAI_KEY" in environ or "AZURE_OPENAI_AD_TOKEN" in environ
if azure:
logger.debug("Using Azure OpenAI environment variables")
else:
logger.debug("Using OpenAI environment variables")
return azure

def safe_min(a: Any, b: Any) -> Any:
if a is None:
return b
if b is None:
return a
return min(a, b)

def safe_max(a: Any, b: Any) -> Any:
if a is None:
return b
if b is None:
return a
return max(a, b)

class UsageStats:
def __init__(self) -> None:
self.start = time.time()
self.completion_tokens = 0
self.prompt_tokens = 0
self.total_tokens = 0
self.end = None
self.duration = 0
self.calls = 0

def __add__(self, other: 'UsageStats') -> 'UsageStats':
stats = UsageStats()
stats.start = safe_min(self.start, other.start)
stats.end = safe_max(self.end, other.end)
stats.completion_tokens = self.completion_tokens + other.completion_tokens
stats.prompt_tokens = self.prompt_tokens + other.prompt_tokens
stats.total_tokens = self.total_tokens + other.total_tokens
stats.duration = self.duration + other.duration
stats.calls = self.calls + other.calls
return stats

class StatsCompleter(ABC):
def __init__(self, create_func):
self.create_func = create_func
self.stats = None
self.lock = Lock()

def __call__(self, *args: Any, **kwds: Any) -> Any:
response = self.create_func(*args, **kwds)
self.lock.acquire()
try:
if not self.stats:
self.stats = UsageStats()
self.stats.completion_tokens += response.usage.completion_tokens
self.stats.prompt_tokens += response.usage.prompt_tokens
self.stats.total_tokens += response.usage.total_tokens
self.stats.calls += 1
return response
finally:
self.lock.release()

def get_stats_and_reset(self) -> UsageStats:
self.lock.acquire()
try:
end = time.time()
stats = self.stats
if stats:
stats.end = end
stats.duration = end - self.stats.start
self.stats = None
return stats
finally:
self.lock.release()

class ChatCompleter(StatsCompleter):
def __init__(self, client):
super().__init__(client.chat.completions.create)

class CompletionsCompleter(StatsCompleter):
def __init__(self, client):
super().__init__(client.completions.create)
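
A hedged sketch of how these pieces compose: build a client from environment variables (key-based, or keyless on Azure when no `AZURE_OPENAI_KEY` is set), wrap its chat endpoint in `ChatCompleter` to accumulate token usage across threads, then read and reset the counters. The model name, message, and print statements are illustrative assumptions; `build_openai_client`, `ChatCompleter`, and `get_stats_and_reset` come from the file above.

```python
from client_utils import build_openai_client, ChatCompleter

# Resolves OPENAI_* / AZURE_OPENAI_* variables with the COMPLETION prefix applied.
client = build_openai_client(env_prefix="COMPLETION")
chat_completer = ChatCompleter(client)

# Drop-in replacement for client.chat.completions.create, with usage tracking.
response = chat_completer(
    model="gpt-4",  # assumed model / Azure deployment name; adjust for your setup
    messages=[{"role": "user", "content": "Generate one question about RAFT."}],
)
print(response.choices[0].message.content)

# Thread-safe snapshot of accumulated usage since the last reset.
stats = chat_completer.get_stats_and_reset()
if stats:
    print(f"{stats.calls} calls, {stats.total_tokens} tokens in {stats.duration:.1f}s")
```

In `raft.py` this is what feeds the tok/s figures shown in the tqdm postfix and the final log.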
19 changes: 19 additions & 0 deletions raft/env_config.py
@@ -1,12 +1,30 @@
import contextlib
import os
import logging

logger = logging.getLogger("env_config")

# List of environment variables prefixes that are allowed to be used for configuration.
env_prefix_whitelist = [
'OPENAI',
'AZURE_OPENAI'
]

def _obfuscate(secret):
l = len(secret)
return '.' * (l - 4) + secret[-4:]

def _log_env(use_prefix: str, env: dict):
"""
Logs each name value pair of the given environment. If the name indicates that it might store a secret such as an API key, then obfuscate the value.
"""
log_prefix = f"'{use_prefix}'" if use_prefix else "no"
logger.info(f"Resolved OpenAI env vars with {log_prefix} prefix:")
for key, value in env.items():
if any(prefix in key for prefix in ['KEY', 'SECRET', 'TOKEN']):
value = _obfuscate(value)
logger.info(f" - {key}={value}")

def read_env_config(use_prefix: str, env: dict = os.environ) -> str:
"""
Read whitelisted environment variables and return them in a dictionary.
@@ -15,6 +33,7 @@ def read_env_config(use_prefix: str, env: dict = os.environ) -> str:
config = {}
for prefix in [None, use_prefix]:
read_env_config_prefixed(prefix, config, env)
_log_env(use_prefix, config)
return config

def read_env_config_prefixed(use_prefix: str, config: dict, env: dict = os.environ) -> str:
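
To illustrate the intent, a small sketch of how the prefixed resolution is used elsewhere in the repo (see `client_utils.py` above). The body of `read_env_config_prefixed` is outside this diff, so the assumption here, that `COMPLETION_OPENAI_API_KEY` takes precedence over `OPENAI_API_KEY` and is exposed without its prefix, is an inference rather than a guarantee. The environment values are hypothetical.

```python
import os
from env_config import read_env_config, set_env

# Hypothetical environment: a shared key plus a completion-specific override.
os.environ["OPENAI_API_KEY"] = "sk-shared-0000"
os.environ["COMPLETION_OPENAI_API_KEY"] = "sk-completion-1111"

# Resolve the effective config for the completion client; _log_env logs each
# resolved pair, obfuscating values whose names contain KEY, SECRET, or TOKEN
# down to their last 4 characters.
env = read_env_config("COMPLETION")

# Temporarily apply the resolved variables, as client_utils.py does when
# constructing the OpenAI / Azure OpenAI clients.
with set_env(**env):
    print(os.environ.get("OPENAI_API_KEY"))  # expected to show the override if the prefix rule holds
```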