diff --git a/train.py b/train.py
index 30d0145ef7..0940ef3aab 100644
--- a/train.py
+++ b/train.py
@@ -45,6 +45,8 @@ wandb_run_name = 'gpt2' # 'run' + str(time.time())
 # data
 dataset = 'openwebtext'
+train_data_path = '' # explicit path to the train dataset, overriding the dataset name convention
+val_data_path = '' # explicit path to the validation dataset, overriding the dataset name convention
 gradient_accumulation_steps = 5 # used to simulate larger batch sizes
 batch_size = 12 # if gradient_accumulation_steps > 1, this is the micro-batch size
 block_size = 1024
@@ -74,7 +76,10 @@ compile = True # use PyTorch 2.0 to compile the model to be faster
 # -----------------------------------------------------------------------------
 config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
-exec(open('configurator.py').read()) # overrides from command line or config file
+
+dir_path = os.path.dirname(os.path.realpath(__file__))
+configurator = os.path.join(dir_path, 'configurator.py')
+exec(open(configurator).read()) # overrides from command line or config file
 config = {k: globals()[k] for k in config_keys} # will be useful for logging
 # -----------------------------------------------------------------------------
@@ -106,6 +111,11 @@
 # poor man's data loader
 data_dir = os.path.join('data', dataset)
-train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
-val_data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r')
+if not train_data_path:
+    train_data_path = os.path.join(data_dir, 'train.bin')
+if not val_data_path:
+    val_data_path = os.path.join(data_dir, 'val.bin')
+train_data = np.memmap(train_data_path, dtype=np.uint16, mode='r')
+val_data = np.memmap(val_data_path, dtype=np.uint16, mode='r')
+
 def get_batch(split):
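
Usage note (a sketch, assuming the stock nanoGPT configurator.py, which overrides these module-level globals from --key=value command-line arguments; the dataset paths below are placeholders): with this change the .bin files no longer have to live under data/<dataset>/, e.g.

    python train.py config/train_gpt2.py --train_data_path=/mnt/datasets/my_train.bin --val_data_path=/mnt/datasets/my_val.bin

If either path is left empty, training falls back to the data/<dataset>/train.bin and data/<dataset>/val.bin convention.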