Commit

added train files

TristanThrush committed Nov 7, 2024
1 parent 165d8c6 commit 92ab182
Showing 11 changed files with 687 additions and 271 deletions.


24 changes: 17 additions & 7 deletions examples/get_error_and_bpb/chunk_pretraining_data_sample.py
@@ -1,4 +1,4 @@
-from datasets import load_dataset, Features, Value
+from datasets import load_dataset, load_from_disk, Features, Value
 from transformers import AutoTokenizer
 from ast import literal_eval
 from tqdm import tqdm
@@ -17,11 +17,20 @@
 with open(args.config, "r") as file:
     config = SimpleNamespace(**yaml.safe_load(file))
 
-ds = load_dataset(
-    config.hf_name,
-    name=config.subset,
-    split=config.split,
-)
+config_load_from_disk = getattr(config, "load_from_disk", False)
+if config_load_from_disk:
+    ds = load_from_disk(
+        config.hf_name,
+        #name=config.subset,
+        #split=config.split,
+    )[config.split]
+else:
+    ds = load_dataset(
+        config.hf_name,
+        name=config.subset,
+        split=config.split,
+    )
 
 if config.subsample_ratio < 1:
     sample_size = int(config.subsample_ratio * len(ds))
     ds = ds.select(range(sample_size))
@@ -46,7 +55,8 @@
     },
     num_proc=config.num_proc,
 )
-ds = ds.rename_column(config.domain_column, "domain")
+if config.domain_column != "domain":
+    ds = ds.rename_column(config.domain_column, "domain")


if config.enforce_pages_per_domain:
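The change makes loading config-driven: when the YAML sets load_from_disk: true, the script reads a dataset saved with save_to_disk and selects the split by indexing (load_from_disk takes only a path, which is why the name/split keyword arguments are commented out); otherwise it falls back to load_dataset from the Hub. A minimal sketch of the same branching, with placeholder config values rather than the values used in this repo:

from types import SimpleNamespace
from datasets import load_dataset, load_from_disk

# Placeholder config; chunk_pretraining_data_sample.py reads these keys from the YAML passed via --config.
config = SimpleNamespace(
    hf_name="path/to/saved_dataset_or_hub_name",
    subset=None,          # Hub config name; unused by load_from_disk
    split="train",
    load_from_disk=True,  # optional flag introduced by this commit
)

# load_from_disk returns the saved DatasetDict, so the split is selected by indexing.
if getattr(config, "load_from_disk", False):
    ds = load_from_disk(config.hf_name)[config.split]
else:
    ds = load_dataset(config.hf_name, name=config.subset, split=config.split)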
73 changes: 73 additions & 0 deletions examples/pretrain_llm/generate_fasttext_sample_weights.py
@@ -0,0 +1,73 @@
import fasttext
from datasets import load_from_disk
import numpy as np
import ast
import argparse
from types import SimpleNamespace
import yaml
import os

parser = argparse.ArgumentParser()
parser.add_argument('--config')
args = parser.parse_args()

with open(args.config, "r") as file:
config = SimpleNamespace(**yaml.safe_load(file))

os.makedirs(config.output_dir, exist_ok=True)

for target in config.targets:
fasttext_model_path = target["fasttext_model_path"]
output_name = target["output_name"]

    model = fasttext.load_model(fasttext_model_path)
total_labels = len(model.get_labels())

ds = load_from_disk(config.hf_dataset)[config.split]

# Run fasttext high-quality (hq) classifier
def classify_text(example):
text = example[config.text_column].replace("\n", " ")
labels, probabilities = model.predict(text, k=total_labels)
if '__label__hq' in labels:
return probabilities[labels.index('__label__hq')]
else:
return probabilities[labels.index('__label__include')]

ds = ds.map(lambda example: {"fasttext_hq_prob": classify_text(example), "doc_id": example[config.id_column]}, remove_columns=ds.column_names, num_proc=64)

    # Build a dict of doc id to fasttext hq prob
doc_id_to_fasttext_hq_prob = {}
def build_fasttext_dict(example):
doc_id_to_fasttext_hq_prob[example["doc_id"]] = example["fasttext_hq_prob"]

ds.map(build_fasttext_dict)

# Get page name to index in the ordered doc ids for pretraining
page_name_to_index = {}
ordered_page_names = np.load(config.file_prefix + "_id.npy")
for i, doc_id in enumerate(ordered_page_names):
page_name_to_index[doc_id] = i

# Create sampling distribution by iteratively including the highest hq prob pages until we match or exceed the desired number of tokens
ordered_token_counts = np.load(config.file_prefix + "_len.npy")
page_names_sorted_by_fasttext_hq_prob = sorted(doc_id_to_fasttext_hq_prob.items(), key=lambda item: item[1], reverse=True)
print("highest hq prob entries:", page_names_sorted_by_fasttext_hq_prob[:20])
print("lowest hq prob entries:", page_names_sorted_by_fasttext_hq_prob[-20:])

current_token_count = 0
num_included_pages = 0
sampling_wt = np.zeros(ordered_page_names.shape)
for doc_id, _ in page_names_sorted_by_fasttext_hq_prob:
if doc_id in page_name_to_index:
doc_token_count = ordered_token_counts[page_name_to_index[doc_id]]
sampling_wt[page_name_to_index[doc_id]] = doc_token_count
current_token_count += doc_token_count
num_included_pages += 1
if current_token_count >= config.desired_token_count:
print(f"created sampling wt. Num included pages={num_included_pages}, Num tokens={current_token_count}")
break

sampling_wt /= sampling_wt.sum()

np.save(os.path.join(config.output_dir, output_name), sampling_wt)
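The saved weight vector is index-aligned with the document order in the file_prefix + "_id.npy" array, so it can be passed directly as the sampling distribution to get_dataset from mmap_utils.py. A minimal sketch of that hookup, assuming the mmap artifacts written by generate_mmap.py and one of the weight files produced above; the paths are illustrative, not taken from the commit:

import numpy as np
from mmap_utils import get_dataset

# Illustrative paths; in practice these come from the mmap and sample-weight configs.
file_prefix = "mmap_datasets/yangjun_synthetic/data"
weights_path = "sample_weights/sciq_yangjun_synthetic.npy"

prob_vec = np.load(weights_path)                    # normalized to sum to 1 by this script
start_map = np.load(file_prefix + "_start.npy")     # token offset of each doc in the mmap
len_map = np.load(file_prefix + "_len.npy")         # token count of each doc
max_tokens = int(np.load(file_prefix + "_metadata.npy"))

# Sample context windows weighted toward the highest fasttext hq prob documents.
# ctx_len should match the context_length used when the mmap was built.
dataset = get_dataset(
    prob_vector=prob_vec,
    ctx_len=128,
    memmaped_file=file_prefix + ".mmap",
    start_map=start_map,
    len_map=len_map,
    max_tokens=max_tokens,
)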
56 changes: 56 additions & 0 deletions examples/pretrain_llm/generate_mmap.py
@@ -0,0 +1,56 @@
from datasets import load_dataset, load_from_disk
from mmap_utils import tokenize_and_mmap, get_dataset
from types import SimpleNamespace
from transformers import AutoTokenizer
import yaml
import numpy as np
import pickle
import argparse
import os

parser = argparse.ArgumentParser()
parser.add_argument("--config")
args = parser.parse_args()

with open(args.config, "r") as file:
config = SimpleNamespace(**yaml.safe_load(file))

def wrap_dataset_iterator(ds, fields):
for data in ds:
yield [data[i] for i in fields]

if __name__ == '__main__':
tokenizer = AutoTokenizer.from_pretrained(config.hf_tokenizer)
    if config.from_disk:
        ds = load_from_disk(config.ds_path)[config.split]
    else:
        # the second positional argument of load_dataset is the config name, so pass the split explicitly
        ds = load_dataset(config.ds_path, split=config.split)
    ds_subset = ds
os.makedirs(config.output_dir, exist_ok=True)
file_prefix = os.path.join(config.output_dir, config.output_file_prefix)
max_tokens = 100*(10**9) # 100 bil tokens

if config.domain is not None:
doc_id_to_domain = {}
def build_doc_id_to_domain(example):
doc_id_to_domain[example[config.id_column]] = example[config.domain]
ds_subset.map(build_doc_id_to_domain)
with open(file_prefix + "_doc_id_to_domain.pkl", 'wb') as f:
pickle.dump(doc_id_to_domain, f)

def tokenize_function(example):
return tokenizer(example[config.text_column], padding="do_not_pad", truncation=False)

# Apply tokenization and remove original columns
ds_subset = ds_subset.map(tokenize_function, remove_columns=[name for name in ds_subset.column_names if name != config.id_column], num_proc=config.num_proc)

tokenize_and_mmap(wrap_dataset_iterator(ds_subset, [config.id_column, 'input_ids']), tokenizer, max_tokens, config.context_length, file_prefix)
len_vecs = np.load(file_prefix + "_len.npy")
prob_vec = len_vecs / np.sum(len_vecs)
    dataset = get_dataset(prob_vector=prob_vec, ctx_len=config.context_length, memmaped_file=file_prefix + ".mmap", start_map=np.load(file_prefix + "_start.npy"), len_map=np.load(file_prefix + "_len.npy"), max_tokens=max_tokens)

for i, data in enumerate(dataset):
print(tokenizer.decode(data["input_ids"]))
print('--------')
if i > 10:
break
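Alongside the .mmap token stream, tokenize_and_mmap writes companion arrays next to it: _start.npy (token offset of each document), _len.npy (token count), _id.npy (the kept document ids), and _metadata.npy (the mmap size). A minimal sketch for reading one document back out of these files as a sanity check; the prefix below is illustrative:

import numpy as np
from transformers import AutoTokenizer

file_prefix = "mmap_datasets/yangjun_synthetic/data"  # illustrative
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-160m")

start_map = np.load(file_prefix + "_start.npy")
len_map = np.load(file_prefix + "_len.npy")
doc_ids = np.load(file_prefix + "_id.npy")
max_tokens = int(np.load(file_prefix + "_metadata.npy"))

tokens = np.memmap(file_prefix + ".mmap", dtype="int32", mode="r", shape=(max_tokens,))

# Recover the first document that made it into the mmap.
doc = tokens[start_map[0] : start_map[0] + len_map[0]]
print(doc_ids[0])
print(tokenizer.decode(doc.tolist()))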
11 changes: 11 additions & 0 deletions examples/pretrain_llm/mmap_configs/yangjun_synthetic_config.yml
@@ -0,0 +1,11 @@
output_file_prefix: data
output_dir: mmap_datasets/yangjun_synthetic
ds_path: ../get_error_and_bpb/pre_chunked_datasets/yangjun_synthetic
from_disk: true
hf_tokenizer: EleutherAI/pythia-160m
domain: domain
text_column: latent
id_column: url
split: train
context_length: 128
num_proc: 16
132 changes: 132 additions & 0 deletions examples/pretrain_llm/mmap_utils.py
@@ -0,0 +1,132 @@
from transformers import AutoTokenizer
import torch
import numpy as np
from datasets import IterableDataset
from tqdm import tqdm
import concurrent.futures
import time

def tokenize_and_mmap(seq_id_iterator: list[tuple[str, list[int]]], tokenizer: AutoTokenizer, max_tokens: int, ctx_len: int, file_prefix: str) -> tuple[np.ndarray, np.ndarray, np.ndarray, list[str]]:
    # given (id, token list) pairs, write each tokenized text to a memmap file in order
    # return the memmap array, the starting index and length of each piece of text, and the ids that were kept
tokenized_mmap_file = np.memmap(file_prefix+'.mmap', dtype='int32', mode='w+', shape=(max_tokens))
len_list = []
cur_idx = 0
id_selected = []
for id, tok_list in tqdm(seq_id_iterator):
tokens = np.array(tok_list)
total_tokens = tokens.size
if cur_idx + total_tokens > max_tokens:
# we could add this last truncated bit, but forget it - messes up indexing.
#truncated_token_ct = max_tokens - cur_idx
#tokenized_mmap_file[cur_idx:] = tokens[:truncated_token_ct]
#len_list.append(truncated_token_ct)
break
if total_tokens >= ctx_len:
tokenized_mmap_file[cur_idx:cur_idx+total_tokens] = tokens
cur_idx += total_tokens
len_list.append(tokens.size)
id_selected.append(id)
if len(len_list) % 100000 == 0:
# periodically flush writes to disk, clear memory
tokenized_mmap_file.flush()
tokenized_mmap_file = np.memmap(file_prefix+'.mmap', dtype='int32', mode='r+', shape=(max_tokens))
start_index = np.array([0] + np.cumsum(len_list)[:-1].tolist())
len_array = np.array(len_list)
# dump both arrays
np.save(file_prefix+'_start.npy', start_index)
np.save(file_prefix+'_len.npy', len_array)
np.save(file_prefix+'_metadata.npy', np.array(max_tokens))
np.save(file_prefix+'_id.npy', np.array(id_selected))
return tokenized_mmap_file, start_index, len_array, id_selected

def sample_from_vec(prob_vector: np.array, batch_size: int, ctx_len: int, memmapped_array: np.array, start_map: np.array, len_map: np.array, gen: np.random.Generator = np.random.Generator(np.random.PCG64())):
# samples tokens in a weighted way from documents.
# samples a doc proportionally to prob_vector.
# within each doc, sample a window of ctx_len uniformly at random.
# returns the sampled batch of token indices
assert(np.min(len_map) >= ctx_len) # can kill this if slow..
# get the document ids
#doc_ids = np.array(random.choices(range(len(prob_vector)), weights=prob_vector, k=batch_size)) #random.choices is slightly faster than numpy
doc_ids = gen.choice(len(prob_vector), p=prob_vector, size=batch_size)
    # now get the offsets within each doc
    offset_ids = gen.integers(len_map[doc_ids] - ctx_len + 1)  # use the provided generator for reproducibility
start_points = start_map[doc_ids] + offset_ids
# do some fancy reshaping to do vectorized indexing
flattened_idx = np.add.outer(start_points, np.arange(ctx_len)).reshape(ctx_len*batch_size)
sampled_batch = memmapped_array[flattened_idx].reshape(batch_size, ctx_len)
return torch.LongTensor(sampled_batch), torch.ones(sampled_batch.shape)

def get_dataset(prob_vector:np.array, ctx_len: int, memmaped_file: str, start_map: np.array, len_map: np.array, max_tokens: int, batch_size = 10000):
def gen():
rng = np.random.Generator(np.random.PCG64())
while True:
temp_memmap = np.memmap(memmaped_file, dtype='int32', mode='r', shape=(max_tokens)) # reinitialize memmap for memory
sampled_batches, masks = sample_from_vec(prob_vector, batch_size, ctx_len, temp_memmap, start_map, len_map, rng)
for i in range(batch_size):
yield {
"input_ids": sampled_batches[i,:].squeeze(),
"labels": sampled_batches[i,:].squeeze(),
"attention_mask": masks[i,:].squeeze()
}
print('get_dataset')
return IterableDataset.from_generator(gen)

def get_dataset_async(prob_vector: np.array, ctx_len: int, memmaped_file: str, start_map: np.array, len_map: np.array,
max_tokens: int, batch_size = 10000):
# async version of the above - used to overlap reads and GPU computation
def gen():
rng = np.random.Generator(np.random.PCG64())
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
future_batch = executor.submit(sample_from_vec, prob_vector, batch_size, ctx_len,
np.memmap(memmaped_file, dtype='int32', mode='r', shape=(max_tokens)),
start_map, len_map, rng)

while True:
start = time.time()
# Wait for the future to complete and get the result
sampled_batches, masks = future_batch.result()

# Submit the next batch generation
future_batch = executor.submit(sample_from_vec, prob_vector, batch_size, ctx_len,
np.memmap(memmaped_file, dtype='int32', mode='r', shape=(max_tokens)),
start_map, len_map, rng)

end = time.time()
print('batch overhead '+str(end-start)+'(s)')
for i in range(batch_size):
yield {
"input_ids": sampled_batches[i,:].squeeze(),
"labels": sampled_batches[i,:].squeeze(),
"attention_mask": masks[i,:].squeeze()
}

print('get_dataset')
return IterableDataset.from_generator(gen)

# Plan for a without-replacement sampler (a sketch of one possible implementation follows this file listing):
# Do modulo rank
# Do modulo seq len
# Convert the prob vector into token counts
# Store a dict of the remaining token count to sample for each page
# Once a page's remaining token count reaches zero, remove it from the dict
# If the remaining-token-count dict is empty, start again. This means we are doing more than one epoch.
# If a page's remaining token count is smaller than seq len, keep sampling from other pages, adding an eos token in between, until you are at seq len.



if __name__ == '__main__':
model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
print('Tokenizer fastness:'+str(tokenizer.is_fast))
max_tokens = 1024
test_strings = ["the quick brown fox jumps over the lazy dog", "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum."]
print(tokenizer(test_strings[0], return_tensors="pt"))
tokenized_list = [tokenizer(seq+tokenizer.eos_token, return_tensors="pt")['input_ids'][0] for seq in test_strings]
    merged_seq, start_map, len_map, id_list = tokenize_and_mmap(enumerate(tokenized_list), tokenizer, max_tokens, 4, 'test')
dataset = get_dataset_async(np.array([0.1, 0.9])[id_list], 4, 'test.mmap', start_map, len_map, max_tokens)
for i, data in enumerate(dataset):
print(tokenizer.decode(data['input_ids'].numpy().tolist()))
#_ = tokenizer.decode(data[0])
if i > 100:
break
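The comment block above (near the end of mmap_utils.py) describes the planned without-replacement sampler in prose only. Below is a minimal, hypothetical sketch of the bookkeeping it outlines: per-page token budgets derived from the sampling distribution, pages removed once their budget is exhausted, epochs restarting when the budget dict empties, and short remainders padded from other pages with an eos token in between. The function name and the total_tokens, eos_token_id, and seed parameters are assumptions, not part of the commit, and the modulo-rank / modulo-seq-len sharding steps are left out:

import numpy as np

def without_replacement_batches(prob_vector, start_map, len_map, memmapped_array,
                                total_tokens, seq_len, eos_token_id, seed=0):
    # Hypothetical sketch: budgets are per-page token counts derived from the
    # sampling distribution; pages are drawn until their budgets run out.
    rng = np.random.default_rng(seed)

    def fresh_budgets():
        return {i: int(round(p * total_tokens)) for i, p in enumerate(prob_vector) if p > 0}

    remaining = fresh_budgets()
    buf = []
    while True:
        if not remaining:
            # All budgets exhausted: start over, i.e. begin another epoch.
            remaining = fresh_budgets()
        page = int(rng.choice(list(remaining.keys())))  # uniform over remaining pages, for simplicity
        take = min(remaining[page], int(len_map[page]), seq_len - len(buf))
        if take > 0:
            offset = int(rng.integers(len_map[page] - take + 1))
            buf.extend(memmapped_array[start_map[page] + offset : start_map[page] + offset + take].tolist())
        remaining[page] -= take
        if remaining[page] <= 0:
            del remaining[page]
        if len(buf) < seq_len:
            # Page ran out before filling the window: move on to another page, separated by eos.
            buf.append(eos_token_id)
        if len(buf) >= seq_len:
            yield np.array(buf[:seq_len], dtype=np.int32)
            buf = []

Like get_dataset, such a generator could be wrapped with IterableDataset.from_generator.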
2 changes: 2 additions & 0 deletions examples/pretrain_llm/requirements.txt
@@ -0,0 +1,2 @@
peft
wandb
16 changes: 16 additions & 0 deletions examples/pretrain_llm/sample_weights_configs/yangjun_synthetic.yml
@@ -0,0 +1,16 @@
file_prefix: yangjun_synthetic/data
desired_token_count: 320000000
output_dir: sample_weights
targets:
- output_name: sciq_yangjun_synthetic.npy
fasttext_model_path: ../get_fasttext_filter/fasttext_models/sciq_yangjun_synthetic.bin
- output_name: piqa_yangjun_synthetic.npy
fasttext_model_path: ../get_fasttext_filter/fasttext_models/piqa_yangjun_synthetic.bin
- output_name: arc_easy_yangjun_synthetic.npy
fasttext_model_path: ../get_fasttext_filter/fasttext_models/arc_easy_yangjun_synthetic.bin
- output_name: lambada_yangjun_synthetic.npy
fasttext_model_path: ../get_fasttext_filter/fasttext_models/lambada_yangjun_synthetic.bin
id_column: url
split: train
text_column: latent
hf_dataset: ../get_error_and_bpb/pre_chunked_datasets/yangjun_synthetic
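Each entry under targets yields one weight file in output_dir. A quick sanity check of one such output; the expectations reflect how generate_fasttext_sample_weights.py builds the vector (normalized, with nonzero mass only on the highest fasttext hq prob pages up to desired_token_count tokens), and the path simply mirrors this config:

import numpy as np

weights = np.load("sample_weights/sciq_yangjun_synthetic.npy")

print("sums to one:", np.isclose(weights.sum(), 1.0))
print("pages with nonzero weight:", int(np.count_nonzero(weights)))
print("pages total:", weights.shape[0])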
