utils.py
import os
import time

import numpy as np
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

from datasets.block import BlockDataset, LatentBlockDataset


def load_cifar():
    """Load the CIFAR-10 train/validation splits, normalized to [-1, 1]."""
    # Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) maps [0, 1] tensors to [-1, 1].
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    train = datasets.CIFAR10(root="data", train=True, download=True,
                             transform=transform)
    val = datasets.CIFAR10(root="data", train=False, download=True,
                           transform=transform)
    return train, val


def load_block():
    """Load the BLOCK trajectory dataset from the local data/ directory."""
    data_file_path = os.path.join(
        os.getcwd(), 'data',
        'randact_traj_length_100_n_trials_1000_n_contexts_1.npy')
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    train = BlockDataset(data_file_path, train=True, transform=transform)
    val = BlockDataset(data_file_path, train=False, transform=transform)
    return train, val


def load_latent_block():
    """Load precomputed discrete latent indices for the BLOCK dataset."""
    data_file_path = os.path.join(os.getcwd(), 'data', 'latent_e_indices.npy')
    train = LatentBlockDataset(data_file_path, train=True, transform=None)
    val = LatentBlockDataset(data_file_path, train=False, transform=None)
    return train, val


def data_loaders(train_data, val_data, batch_size):
    """Wrap the datasets in shuffling, pinned-memory DataLoaders."""
    train_loader = DataLoader(train_data,
                              batch_size=batch_size,
                              shuffle=True,
                              pin_memory=True)
    val_loader = DataLoader(val_data,
                            batch_size=batch_size,
                            shuffle=True,
                            pin_memory=True)
    return train_loader, val_loader
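
# Note (not part of the original file): the validation loader above shuffles
# as well, which is harmless for aggregate metrics. If deterministic
# evaluation batches are wanted, a non-shuffled variant can be built the same
# way, e.g.:
#
#     fixed_val_loader = DataLoader(val_data, batch_size=batch_size,
#                                   shuffle=False, pin_memory=True)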


def load_data_and_data_loaders(dataset, batch_size):
    """Return (train_data, val_data, train_loader, val_loader, x_train_var)."""
    if dataset == 'CIFAR10':
        training_data, validation_data = load_cifar()
        training_loader, validation_loader = data_loaders(
            training_data, validation_data, batch_size)
        # Pixel variance of the raw training images, commonly used downstream
        # to normalize the reconstruction loss. torchvision >= 0.2.2 exposes
        # the raw array as .data (the old .train_data attribute was removed).
        x_train_var = np.var(training_data.data / 255.0)
    elif dataset == 'BLOCK':
        training_data, validation_data = load_block()
        training_loader, validation_loader = data_loaders(
            training_data, validation_data, batch_size)
        x_train_var = np.var(training_data.data / 255.0)
    elif dataset == 'LATENT_BLOCK':
        training_data, validation_data = load_latent_block()
        training_loader, validation_loader = data_loaders(
            training_data, validation_data, batch_size)
        x_train_var = np.var(training_data.data)
    else:
        raise ValueError(
            'Invalid dataset: only CIFAR10, BLOCK, and LATENT_BLOCK '
            'datasets are supported.')

    return training_data, validation_data, training_loader, validation_loader, x_train_var


def readable_timestamp():
    """Return a filesystem-friendly timestamp, e.g. 'fri_mar_1_14_05_09_2024'."""
    # ctime() pads single-digit days with an extra space; collapse it first.
    return time.ctime().replace('  ', ' ').replace(
        ' ', '_').replace(':', '_').lower()


def save_model_and_results(model, results, hyperparameters, timestamp):
    """Save model weights, metrics, and settings under ./results."""
    SAVE_MODEL_PATH = os.path.join(os.getcwd(), 'results')
    os.makedirs(SAVE_MODEL_PATH, exist_ok=True)  # create ./results if missing

    results_to_save = {
        'model': model.state_dict(),
        'results': results,
        'hyperparameters': hyperparameters
    }
    torch.save(results_to_save,
               os.path.join(SAVE_MODEL_PATH,
                            'vqvae_data_' + timestamp + '.pth'))
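

# Example usage (a minimal sketch, not part of the original module): this
# assumes CIFAR-10 can be downloaded into ./data and exercises only the data
# path. The `model`, `results`, and `hyperparameters` names in the commented
# save call are hypothetical placeholders for objects a training loop would
# produce.
if __name__ == '__main__':
    timestamp = readable_timestamp()
    (training_data, validation_data,
     training_loader, validation_loader,
     x_train_var) = load_data_and_data_loaders('CIFAR10', batch_size=32)
    print('x_train_var:', x_train_var)
    images, labels = next(iter(training_loader))
    print('batch shape:', images.shape)  # expected: torch.Size([32, 3, 32, 32])
    # After training, persist everything in one checkpoint:
    # save_model_and_results(model, results, hyperparameters, timestamp)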