Commit 92ab182 (1 parent: 165d8c6)
Showing 11 changed files with 687 additions and 271 deletions.
264 changes: 0 additions & 264 deletions
examples/get_error_and_bpb/bpb_csvs/chunked_synthetic_sample_bpb_id.csv
This file was deleted.
@@ -0,0 +1,73 @@
import fasttext
from datasets import load_from_disk
import numpy as np
import ast
import argparse
from types import SimpleNamespace
import yaml
import os

parser = argparse.ArgumentParser()
parser.add_argument('--config')
args = parser.parse_args()

with open(args.config, "r") as file:
    config = SimpleNamespace(**yaml.safe_load(file))

os.makedirs(config.output_dir, exist_ok=True)

for target in config.targets:
    fasttext_model_path = target["fasttext_model_path"]
    output_name = target["output_name"]

    model = fasttext.load_model(fasttext_model_path)  # e.g. 'openhermes_reddit_eli5_vs_rw_v2_bigram_200k_train.bin'
    total_labels = len(model.get_labels())

    ds = load_from_disk(config.hf_dataset)[config.split]

    # Run the fasttext high-quality (hq) classifier on each document
    def classify_text(example):
        text = example[config.text_column].replace("\n", " ")
        labels, probabilities = model.predict(text, k=total_labels)
        if '__label__hq' in labels:
            return probabilities[labels.index('__label__hq')]
        else:
            return probabilities[labels.index('__label__include')]

    ds = ds.map(lambda example: {"fasttext_hq_prob": classify_text(example), "doc_id": example[config.id_column]},
                remove_columns=ds.column_names, num_proc=64)

    # Build a dict from doc id to fasttext hq prob
    doc_id_to_fasttext_hq_prob = {}
    def build_fasttext_dict(example):
        doc_id_to_fasttext_hq_prob[example["doc_id"]] = example["fasttext_hq_prob"]

    ds.map(build_fasttext_dict)

    # Map each page name to its index in the ordered doc ids used for pretraining
    page_name_to_index = {}
    ordered_page_names = np.load(config.file_prefix + "_id.npy")
    for i, doc_id in enumerate(ordered_page_names):
        page_name_to_index[doc_id] = i

    # Create the sampling distribution by iteratively including the highest hq-prob pages
    # until we match or exceed the desired number of tokens
    ordered_token_counts = np.load(config.file_prefix + "_len.npy")
    page_names_sorted_by_fasttext_hq_prob = sorted(doc_id_to_fasttext_hq_prob.items(), key=lambda item: item[1], reverse=True)
    print("highest hq prob entries:", page_names_sorted_by_fasttext_hq_prob[:20])
    print("lowest hq prob entries:", page_names_sorted_by_fasttext_hq_prob[-20:])

    current_token_count = 0
    num_included_pages = 0
    sampling_wt = np.zeros(ordered_page_names.shape)
    for doc_id, _ in page_names_sorted_by_fasttext_hq_prob:
        if doc_id in page_name_to_index:
            doc_token_count = ordered_token_counts[page_name_to_index[doc_id]]
            sampling_wt[page_name_to_index[doc_id]] = doc_token_count
            current_token_count += doc_token_count
            num_included_pages += 1
            if current_token_count >= config.desired_token_count:
                print(f"created sampling wt. Num included pages={num_included_pages}, Num tokens={current_token_count}")
                break

    sampling_wt /= sampling_wt.sum()

    np.save(os.path.join(config.output_dir, output_name), sampling_wt)
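The saved weight vector is index-aligned with the `_id.npy` / `_len.npy` arrays written by the mmap tokenization step, so it can be passed directly as the sampling distribution for the weighted sampler defined later in this commit. A minimal hand-off sketch, assuming paths derived from the configs below (the exact file layout is an assumption, not shown in this diff):

# Illustrative only: consume a saved sampling-weight vector with mmap_utils.get_dataset.
import numpy as np
from mmap_utils import get_dataset

file_prefix = "mmap_datasets/yangjun_synthetic/data"                # from the mmap config below
sampling_wt = np.load("sample_weights/sciq_yangjun_synthetic.npy")  # written by the loop above

dataset = get_dataset(
    prob_vector=sampling_wt,
    ctx_len=128,                                 # context_length in the mmap config
    memmaped_file=file_prefix + ".mmap",
    start_map=np.load(file_prefix + "_start.npy"),
    len_map=np.load(file_prefix + "_len.npy"),
    max_tokens=100 * (10**9),                    # must match the size used when the memmap was written
)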
@@ -0,0 +1,56 @@
from datasets import load_dataset, load_from_disk
from mmap_utils import tokenize_and_mmap, get_dataset
from types import SimpleNamespace
from transformers import AutoTokenizer
import yaml
import numpy as np
import pickle
import argparse
import os

parser = argparse.ArgumentParser()
parser.add_argument("--config")
args = parser.parse_args()

with open(args.config, "r") as file:
    config = SimpleNamespace(**yaml.safe_load(file))

def wrap_dataset_iterator(ds, fields):
    for data in ds:
        yield [data[i] for i in fields]

if __name__ == '__main__':
    tokenizer = AutoTokenizer.from_pretrained(config.hf_tokenizer)
    if config.from_disk:
        ds = load_from_disk(config.ds_path)[config.split]
    else:
        ds = load_dataset(config.ds_path, split=config.split)
    ds_subset = ds
    os.makedirs(config.output_dir, exist_ok=True)
    file_prefix = os.path.join(config.output_dir, config.output_file_prefix)
    max_tokens = 100 * (10**9)  # 100 billion tokens

    if config.domain is not None:
        # Save a doc id -> domain lookup for later debugging/analysis
        doc_id_to_domain = {}
        def build_doc_id_to_domain(example):
            doc_id_to_domain[example[config.id_column]] = example[config.domain]
        ds_subset.map(build_doc_id_to_domain)
        with open(file_prefix + "_doc_id_to_domain.pkl", 'wb') as f:
            pickle.dump(doc_id_to_domain, f)

    def tokenize_function(example):
        return tokenizer(example[config.text_column], padding="do_not_pad", truncation=False)

    # Apply tokenization and remove the original columns (keeping only the id column)
    ds_subset = ds_subset.map(tokenize_function,
                              remove_columns=[name for name in ds_subset.column_names if name != config.id_column],
                              num_proc=config.num_proc)

    # Write the tokenized documents into a memmap file plus index arrays
    tokenize_and_mmap(wrap_dataset_iterator(ds_subset, [config.id_column, 'input_ids']), tokenizer, max_tokens, config.context_length, file_prefix)

    # Sanity check: sample a few sequences proportionally to document length and print them
    len_vecs = np.load(file_prefix + "_len.npy")
    prob_vec = len_vecs / np.sum(len_vecs)
    dataset = get_dataset(prob_vector=prob_vec, ctx_len=config.context_length, memmaped_file=file_prefix + ".mmap",
                          start_map=np.load(file_prefix + "_start.npy"), len_map=np.load(file_prefix + "_len.npy"),
                          max_tokens=max_tokens)

    for i, data in enumerate(dataset):
        print(tokenizer.decode(data["input_ids"]))
        print('--------')
        if i > 10:
            break
11 changes: 11 additions & 0 deletions
examples/pretrain_llm/mmap_configs/yangjun_synthetic_config.yml
@@ -0,0 +1,11 @@
output_file_prefix: data
output_dir: mmap_datasets/yangjun_synthetic
ds_path: ../get_error_and_bpb/pre_chunked_datasets/yangjun_synthetic
from_disk: true
hf_tokenizer: EleutherAI/pythia-160m
domain: domain
text_column: latent
id_column: url
split: train
context_length: 128
num_proc: 16
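With output_dir: mmap_datasets/yangjun_synthetic and output_file_prefix: data, the tokenization script above writes a token memmap plus index arrays under that prefix. A small, illustrative sketch of loading those artifacts back (file names follow tokenize_and_mmap; the concrete paths are derived from this config, not shown in the diff):

# Illustrative sanity check of the arrays written by tokenize_and_mmap for this config.
import numpy as np

file_prefix = "mmap_datasets/yangjun_synthetic/data"
start_map = np.load(file_prefix + "_start.npy")  # start offset of each document in the memmap
len_map = np.load(file_prefix + "_len.npy")      # token count of each document
doc_ids = np.load(file_prefix + "_id.npy")       # document ids (the `url` column), in write order
assert len(start_map) == len(len_map) == len(doc_ids)
print(f"{len(doc_ids)} documents, {len_map.sum()} tokens written")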
@@ -0,0 +1,132 @@
from transformers import AutoTokenizer
import torch
import numpy as np
from datasets import IterableDataset
from tqdm import tqdm
import concurrent.futures
import time
from typing import Iterable, List, Tuple

def tokenize_and_mmap(seq_id_iterator: Iterable[Tuple[str, List[int]]], tokenizer: AutoTokenizer, max_tokens: int,
                      ctx_len: int, file_prefix: str) -> Tuple[np.memmap, np.ndarray, np.ndarray, List[str]]:
    # Given an iterator of (id, token list) pairs, write each document's tokens to a memmap file in order.
    # Returns the memmap array, the start index and length of each piece of text, and the ids that were kept.
    tokenized_mmap_file = np.memmap(file_prefix + '.mmap', dtype='int32', mode='w+', shape=(max_tokens,))
    len_list = []
    cur_idx = 0
    id_selected = []
    for id, tok_list in tqdm(seq_id_iterator):
        tokens = np.array(tok_list)
        total_tokens = tokens.size
        if cur_idx + total_tokens > max_tokens:
            # we could add this last truncated bit, but forget it - it messes up the indexing.
            # truncated_token_ct = max_tokens - cur_idx
            # tokenized_mmap_file[cur_idx:] = tokens[:truncated_token_ct]
            # len_list.append(truncated_token_ct)
            break
        if total_tokens >= ctx_len:
            tokenized_mmap_file[cur_idx:cur_idx + total_tokens] = tokens
            cur_idx += total_tokens
            len_list.append(tokens.size)
            id_selected.append(id)
            if len(len_list) % 100000 == 0:
                # periodically flush writes to disk and clear memory
                tokenized_mmap_file.flush()
                tokenized_mmap_file = np.memmap(file_prefix + '.mmap', dtype='int32', mode='r+', shape=(max_tokens,))
    start_index = np.array([0] + np.cumsum(len_list)[:-1].tolist())
    len_array = np.array(len_list)
    # dump the index arrays alongside the memmap
    np.save(file_prefix + '_start.npy', start_index)
    np.save(file_prefix + '_len.npy', len_array)
    np.save(file_prefix + '_metadata.npy', np.array(max_tokens))
    np.save(file_prefix + '_id.npy', np.array(id_selected))
    return tokenized_mmap_file, start_index, len_array, id_selected

def sample_from_vec(prob_vector: np.ndarray, batch_size: int, ctx_len: int, memmapped_array: np.ndarray,
                    start_map: np.ndarray, len_map: np.ndarray,
                    gen: np.random.Generator = np.random.Generator(np.random.PCG64())):
    # Samples tokens from documents in a weighted way:
    # a doc is sampled proportionally to prob_vector, then a window of ctx_len is sampled uniformly at random within it.
    # Returns the sampled batch of token ids and an all-ones attention mask.
    assert np.min(len_map) >= ctx_len  # can drop this check if it turns out to be slow
    # get the document ids
    # doc_ids = np.array(random.choices(range(len(prob_vector)), weights=prob_vector, k=batch_size))  # random.choices is slightly faster than numpy
    doc_ids = gen.choice(len(prob_vector), p=prob_vector, size=batch_size)
    # now get the offsets within each document
    offset_ids = np.random.randint(len_map[doc_ids] - ctx_len + 1)
    start_points = start_map[doc_ids] + offset_ids
    # do some fancy reshaping to do vectorized indexing
    flattened_idx = np.add.outer(start_points, np.arange(ctx_len)).reshape(ctx_len * batch_size)
    sampled_batch = memmapped_array[flattened_idx].reshape(batch_size, ctx_len)
    return torch.LongTensor(sampled_batch), torch.ones(sampled_batch.shape)
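
# Worked example of the vectorized indexing in sample_from_vec: with start_points = [3, 10] and
# ctx_len = 4, np.add.outer(start_points, np.arange(4)) yields [[3, 4, 5, 6], [10, 11, 12, 13]];
# flattening that lets the memmap be read with a single fancy-indexing call, and the final
# reshape restores the (batch_size, ctx_len) layout.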

def get_dataset(prob_vector: np.ndarray, ctx_len: int, memmaped_file: str, start_map: np.ndarray,
                len_map: np.ndarray, max_tokens: int, batch_size=10000):
    def gen():
        rng = np.random.Generator(np.random.PCG64())
        while True:
            # reinitialize the memmap on each batch to keep memory usage bounded
            temp_memmap = np.memmap(memmaped_file, dtype='int32', mode='r', shape=(max_tokens,))
            sampled_batches, masks = sample_from_vec(prob_vector, batch_size, ctx_len, temp_memmap, start_map, len_map, rng)
            for i in range(batch_size):
                yield {
                    "input_ids": sampled_batches[i, :].squeeze(),
                    "labels": sampled_batches[i, :].squeeze(),
                    "attention_mask": masks[i, :].squeeze()
                }
    print('get_dataset')
    return IterableDataset.from_generator(gen)

def get_dataset_async(prob_vector: np.ndarray, ctx_len: int, memmaped_file: str, start_map: np.ndarray,
                      len_map: np.ndarray, max_tokens: int, batch_size=10000):
    # async version of get_dataset - used to overlap memmap reads with GPU computation
    def gen():
        rng = np.random.Generator(np.random.PCG64())
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
            future_batch = executor.submit(sample_from_vec, prob_vector, batch_size, ctx_len,
                                           np.memmap(memmaped_file, dtype='int32', mode='r', shape=(max_tokens,)),
                                           start_map, len_map, rng)

            while True:
                start = time.time()
                # Wait for the in-flight batch to complete and get the result
                sampled_batches, masks = future_batch.result()

                # Submit the next batch generation
                future_batch = executor.submit(sample_from_vec, prob_vector, batch_size, ctx_len,
                                               np.memmap(memmaped_file, dtype='int32', mode='r', shape=(max_tokens,)),
                                               start_map, len_map, rng)

                end = time.time()
                print('batch overhead ' + str(end - start) + '(s)')
                for i in range(batch_size):
                    yield {
                        "input_ids": sampled_batches[i, :].squeeze(),
                        "labels": sampled_batches[i, :].squeeze(),
                        "attention_mask": masks[i, :].squeeze()
                    }

    print('get_dataset')
    return IterableDataset.from_generator(gen)

# Plan for the without-replacement sampler:
# - Do modulo rank.
# - Do modulo seq len.
# - Convert the prob vector into token counts.
# - Store a dict of the remaining token counts to sample for each page.
# - Once a page's remaining token count gets to zero, remove it from the dict.
# - If the remaining token count dict is empty, start again. This means we are doing more than one epoch.
# - If the remaining token count is smaller than seq len, keep sampling from other pages, adding an eos token in between, until you are at seq len.

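# The sketch below is added for illustration only (it is not part of this commit): one way the
# plan above could be implemented. `eos_token_id` and `total_tokens` are assumed inputs, the
# rank / seq-len sharding steps are omitted, and each draw reads from the start of a page rather
# than tracking per-page offsets.
def sample_without_replacement_sketch(prob_vector, ctx_len, memmapped_array, start_map, len_map,
                                      eos_token_id, total_tokens, seed=0):
    # Convert the probability vector into a per-page token budget.
    budgets = {i: int(p * total_tokens) for i, p in enumerate(prob_vector)}
    budgets = {i: b for i, b in budgets.items() if b > 0}
    remaining = dict(budgets)
    rng = np.random.Generator(np.random.PCG64(seed))
    chunk = []
    while True:
        if not remaining:
            remaining = dict(budgets)  # all budgets exhausted: start another epoch
        page = int(rng.choice(list(remaining.keys())))
        take = int(min(remaining[page], len_map[page], ctx_len - len(chunk)))
        start = start_map[page]
        chunk.extend(memmapped_array[start:start + take].tolist())
        remaining[page] -= take
        if remaining[page] <= 0:
            del remaining[page]  # this page's budget is used up
        if len(chunk) < ctx_len:
            chunk.append(eos_token_id)  # separate pages with an eos token and keep filling
            continue
        yield torch.LongTensor(chunk[:ctx_len])
        chunk = []
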
if __name__ == '__main__':
    model_name = "gpt2"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    print('Tokenizer fastness: ' + str(tokenizer.is_fast))
    max_tokens = 1024
    test_strings = ["the quick brown fox jumps over the lazy dog",
                    "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum."]
    print(tokenizer(test_strings[0], return_tensors="pt"))
    tokenized_list = [tokenizer(seq + tokenizer.eos_token, return_tensors="pt")['input_ids'][0] for seq in test_strings]
    merged_seq, start_map, len_map, id_list = tokenize_and_mmap(enumerate(tokenized_list), tokenizer, max_tokens, 4, 'test')
    dataset = get_dataset_async(np.array([0.1, 0.9])[id_list], 4, 'test.mmap', start_map, len_map, max_tokens)
    for i, data in enumerate(dataset):
        print(tokenizer.decode(data['input_ids'].numpy().tolist()))
        # _ = tokenizer.decode(data[0])
        if i > 100:
            break
@@ -0,0 +1,2 @@
peft
wandb
16 changes: 16 additions & 0 deletions
examples/pretrain_llm/sample_weights_configs/yangjun_synthetic.yml
@@ -0,0 +1,16 @@
file_prefix: yangjun_synthetic/data
desired_token_count: 320000000
output_dir: sample_weights
targets:
  - output_name: sciq_yangjun_synthetic.npy
    fasttext_model_path: ../get_fasttext_filter/fasttext_models/sciq_yangjun_synthetic.bin
  - output_name: piqa_yangjun_synthetic.npy
    fasttext_model_path: ../get_fasttext_filter/fasttext_models/piqa_yangjun_synthetic.bin
  - output_name: arc_easy_yangjun_synthetic.npy
    fasttext_model_path: ../get_fasttext_filter/fasttext_models/arc_easy_yangjun_synthetic.bin
  - output_name: lambada_yangjun_synthetic.npy
    fasttext_model_path: ../get_fasttext_filter/fasttext_models/lambada_yangjun_synthetic.bin
id_column: url
split: train
text_column: latent
hf_dataset: ../get_error_and_bpb/pre_chunked_datasets/yangjun_synthetic
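With this config, the sampling-weights script earlier in the commit writes one normalized weight vector per target into sample_weights/, each index-aligned with yangjun_synthetic/data_id.npy. A short, illustrative check of those outputs (it assumes the script has already been run with this config):

# Illustrative only: verify the per-target weight vectors produced from this config.
import numpy as np

ordered_ids = np.load("yangjun_synthetic/data_id.npy")
for name in ["sciq", "piqa", "arc_easy", "lambada"]:
    wt = np.load(f"sample_weights/{name}_yangjun_synthetic.npy")
    assert wt.shape == ordered_ids.shape   # one weight per document in the mmap ordering
    assert abs(wt.sum() - 1.0) < 1e-6      # normalized by the script
    print(name, "pages with nonzero weight:", int((wt > 0).sum()))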