-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtrain_MNIST_unsup.py
113 lines (89 loc) · 4.34 KB
/
train_MNIST_unsup.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import numpy as np
import matplotlib.pyplot as plt
import os
import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
from torchsummary import summary
from Train.train_CNP_images import train_CNP_unsup
from CNPs.create_model import create_model
from Utils.data_loader import load_data_unsupervised
from Utils.helper_results import qualitative_evaluation_images, plot_loss
"""
def plot_loss(loss_dir_txt,loss_dir_plot):
loss = []
with open(loss_dir_txt,"r") as f:
for x in f.read().split():
if x != "":
loss.append(int(float(x)))
# plot
plt.figure()
plt.plot(np.arange(1,len(loss)+1),loss)
plt.xlabel("Epoch",fontsize=15)
plt.ylabel("Negative log-likelihood",fontsize=15)
plt.savefig(loss_dir_plot)
"""
if __name__ == "__main__":
# use GPU if available
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# type of model
model_name = "UNet_restrained" # one of ["CNP", "ConvCNP", "ConvCNPXL", "UNetCNP", "UNet_restrained"]
train = True
load = False
save = True
if load:
epoch_start = 400 # which epoch to start from
else:
epoch_start = 0
save_freq = 50 # epoch frequency of saving checkpoints
semantics = True
# parameters
batch_size = 64
validation_split = 0.10
learning_rate = 1e-4
if train:
epochs = 400
else:
epochs = 0
# create the models
model, convolutional = create_model(model_name)
model.to(device)
# print a summary of the model
if convolutional:
summary(model,[(1,28,28),(1,28,28)])
else:
summary(model, [(50, 2), (50, 1), (784,2)])
# load the MNIST data
train_data, valid_data, test_data = load_data_unsupervised(batch_size,validation_split=validation_split)
# directories
model_save_dir = ["saved_models/MNIST/", model_name + ("_semantics" if semantics else ""), "/",model_name + ("_semantics" if semantics else ""),"_","","E",".pth"]
train_loss_dir_txt = "saved_models/MNIST/" + model_name + ("_semantics" if semantics else "") + "/loss/train_" + model_name + ("_semantics" if semantics else "") + ".txt"
validation_loss_dir_txt = "saved_models/MNIST/" + model_name + ("_semantics" if semantics else "") + "/loss/validation_" + model_name + ("_semantics" if semantics else "") + ".txt"
loss_dir_plot = "saved_models/MNIST/" + model_name + ("_semantics" if semantics else "") + "/loss/" + model_name + ("_semantics" if semantics else "") + ".svg"
visualisation_dir = ["saved_models/MNIST/", model_name + ("_semantics" if semantics else ""), "/visualisation/",model_name + ("_semantics" if semantics else ""),"_","","E_","","C.svg"]
# create directories for the checkpoints and loss files if they don't exist yet
dir_to_create = "".join(model_save_dir[:3]) + "loss/"
os.makedirs(dir_to_create, exist_ok=True)
if load:
load_dir = model_save_dir.copy()
load_dir[5] = str(epoch_start)
load_dir = "".join(load_dir)
# check if the loss file is valid
with open(train_loss_dir_txt, 'r') as f:
nbr_losses = len(f.read().split())
assert nbr_losses == epoch_start, "The number of lines in the loss file does not correspond to the number of epochs"
# load the model
model.load_state_dict(torch.load(load_dir,map_location=device))
else:
# if train from scratch, check if a loss file already exists
if train or save:
assert not(os.path.isfile(train_loss_dir_txt)), "The corresponding loss file already exists, please remove it to train from scratch: " + train_loss_dir_txt
if train:
avg_loss_per_epoch = train_CNP_unsup(train_data, model, epochs, model_save_dir, train_loss_dir_txt, semantics=semantics, validation_data=valid_data, validation_loss_dir_txt=validation_loss_dir_txt, convolutional=convolutional, visualisation_dir=visualisation_dir, save_freq=save_freq, epoch_start=epoch_start, device=device, learning_rate=learning_rate)
plot_loss([train_loss_dir_txt,validation_loss_dir_txt],loss_dir_plot)
if save:
save_dir = model_save_dir.copy()
save_dir[5] = str(epoch_start + epochs)
save_dir = "".join(save_dir)
torch.save(model.state_dict(),save_dir)