Commit
Replace fire with argparse
gordicaleksa committed Aug 10, 2024
1 parent 311c149 commit eb1b359
Showing 4 changed files with 15 additions and 18 deletions.
dev/data/README.md: 6 changes (3 additions, 3 deletions)
@@ -2,9 +2,9 @@
 
 The idea is that each dataset has a .py file here in the root of `dev/data`, and each dataset then creates a directory here, and writes and caches anything inside that directory. So for example:
 
-- running `python tinystories.py --model gpt-2` will create a directory `tinystories` with its .bin files inside it
-- running `python tinyshakespeare.py --model gpt-2` will create a directory `tinyshakespeare` with its .bin files inside it
+- running `python tinystories.py` will create a directory `tinystories` with its .bin files inside it
+- running `python tinyshakespeare.py` will create a directory `tinyshakespeare` with its .bin files inside it
 
 And so on. This way we can nicely organize multiple datasets here, share common utilities between them, and then point the .py/.c code in the root of the project accordingly to these.
 
-Note: `--model` can be either "gpt-2" or "llama" (we assume llama 3).
+Note: we support "gpt-2" and "llama" (llama 3 in particular) models and the above scripts will tokenize gpt-2 by default.
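
With the new default in place, a non-default run would look like `python tinystories.py --model llama` (an illustrative invocation based on the `--model` flag the scripts define below, not a line from this commit).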
dev/data/tinyshakespeare.py: 13 changes (6 additions, 7 deletions)
@@ -19,9 +19,9 @@
 streams of uint16 (gpt-2) or uint32 (llama) numbers indicating the token ids.
 """
 
+import argparse
 import os
 
-import fire
 import tiktoken
 from transformers import AutoTokenizer
 
@@ -69,10 +69,9 @@ def encode(x):
     write_datafile(val_filename, val_tokens, model)
     write_datafile(train_filename, train_tokens, model)
 
-def process(model):
-    assert model in ["gpt-2", "llama"], f"unknown model {model} (choose from gpt-2, llama)"
-    download()
-    tokenize(model)
-
 if __name__ == "__main__":
-    fire.Fire(process)
+    parser = argparse.ArgumentParser(description="Tiny Shakespeare dataset preprocessing")
+    parser.add_argument("-m", "--model", type=str, default="gpt-2", choices=["gpt-2", "llama"], help="Model type, gpt-2|llama")
+    args = parser.parse_args()
+    download()
+    tokenize(args.model)
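
For reference, the resulting command-line interface can be exercised as follows (illustrative invocations, assuming they are run from `dev/data`; the flag name, default, and choices come from the `add_argument` call above):

    python tinyshakespeare.py                 # tokenizes with gpt-2, the default
    python tinyshakespeare.py --model llama   # tokenizes with the llama 3 tokenizer
    python tinyshakespeare.py -m llama        # short form of the same flag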
dev/data/tinystories.py: 13 changes (6 additions, 7 deletions)
@@ -25,13 +25,13 @@
 streams of uint16 (gpt-2) or uint32 (llama) numbers indicating the token ids.
 """
 
+import argparse
 import os
 import glob
 import json
 import random
 from concurrent.futures import ProcessPoolExecutor, as_completed
 
-import fire
 import tiktoken
 from transformers import AutoTokenizer
 
@@ -116,10 +116,9 @@ def tokenize(model):
         split_filename = os.path.join(DATA_CACHE_DIR, f"TinyStories_{split_name}.bin")
         write_datafile(split_filename, all_tokens, model)
 
-def process(model):
-    assert model in ["gpt-2", "llama"], f"unknown model {model} (choose from gpt-2, llama)"
-    download()
-    tokenize(model)
-
 if __name__ == "__main__":
-    fire.Fire(process)
+    parser = argparse.ArgumentParser(description="Tiny Stories dataset preprocessing")
+    parser.add_argument("-m", "--model", type=str, default="gpt-2", choices=["gpt-2", "llama"], help="Model type, gpt-2|llama")
+    args = parser.parse_args()
+    download()
+    tokenize(args.model)
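
A note on the design choice: `fire.Fire(process)` derived the CLI implicitly from the signature of `process`, so an invalid model name was only caught at runtime by the `assert`. The explicit `argparse` setup moves that validation into `choices=`, makes the gpt-2 default visible in one place, and provides `--help` output for free, at the cost of repeating a few lines in each script.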
requirements.txt: 1 change (0 additions, 1 deletion)
@@ -1,4 +1,3 @@
-fire
 tqdm
 numpy<2
 torch
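
Since `argparse` ships with the Python standard library, dropping `fire` removes a third-party dependency without adding a replacement entry here.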
