-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy pathlayout.py
119 lines (90 loc) · 4.12 KB
/
layout.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import argparse
import csv
import numpy as np
import tqdm
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
from torch.utils.data.sampler import SubsetRandomSampler
from UI_embedding.plotter import plot_loss
from autoencoder import ScreenLayoutDataset, LayoutAutoEncoder, LayoutTrainer
from autoencoder import ScreenVisualLayout, ScreenVisualLayoutDataset, ImageAutoEncoder, ImageTrainer
# Entry script: trains the layout autoencoder (--type 0) or the visual
# autoencoder (--type 1) on a dataset of screens.
parser = argparse.ArgumentParser()
# (flags, options) table keeps the CLI surface in one place.
_ARG_SPECS = [
    (("-d", "--dataset"), dict(required=True, type=str, help="dataset of screens to train on")),
    (("-b", "--batch_size"), dict(type=int, default=64, help="traces in a batch")),
    (("-e", "--epochs"), dict(type=int, default=10, help="number of epochs")),
    (("-r", "--rate"), dict(type=float, default=0.001, help="learning rate")),
    (("-t", "--type"), dict(type=int, default=0, help="0 to create layout autoencoder, 1 to create visual autoencoder")),
]
for _flags, _options in _ARG_SPECS:
    parser.add_argument(*_flags, **_options)
args = parser.parse_args()
def _make_loaders(dataset, batch_size, val_fraction=0.1):
    """Split `dataset` into shuffled train/validation DataLoaders.

    `val_fraction` of the samples (floor) go to the validation split.
    Returns (train_loader, test_loader). SubsetRandomSampler draws only
    from its index list, so the two loaders see disjoint samples of the
    same underlying dataset.
    """
    dataset_size = len(dataset)
    indices = list(range(dataset_size))
    split = int(np.floor(val_fraction * dataset_size))
    np.random.shuffle(indices)
    train_indices, val_indices = indices[split:], indices[:split]
    train_loader = DataLoader(dataset, batch_size=batch_size,
                              sampler=SubsetRandomSampler(train_indices))
    test_loader = DataLoader(dataset, batch_size=batch_size,
                             sampler=SubsetRandomSampler(val_indices))
    return train_loader, test_loader


def _run_training(trainer, epochs, out_path, checkpoint_path=None):
    """Run the per-epoch train/test loop, checkpointing every 50 epochs.

    Checkpoints fire on epoch 0, 50, 100, ... When `checkpoint_path` is
    None, trainer.save is called with only the epoch number (trainer's
    default location); otherwise the path is forwarded. After the loop,
    the loss curves are plotted and a final model saved to `out_path`.
    Returns (train_loss_data, test_loss_data) lists of per-epoch losses.
    """
    train_loss_data = []
    test_loss_data = []
    for epoch in tqdm.tqdm(range(epochs)):
        print("--------")
        print(str(epoch) + " loss:")
        train_loss = trainer.train(epoch)
        print(train_loss)
        print("--------")
        train_loss_data.append(train_loss)
        test_loss = trainer.test(epoch)
        test_loss_data.append(test_loss)
        print(test_loss)
        print("--------")
        if (epoch % 50) == 0:
            print("saved on epoch " + str(epoch))
            if checkpoint_path is None:
                trainer.save(epoch)
            else:
                trainer.save(epoch, checkpoint_path)
    plot_loss(train_loss_data, test_loss_data, out_path)
    trainer.save(epochs, out_path)
    return train_loss_data, test_loss_data


if args.type == 0:
    # Structural (layout) autoencoder.
    train_loader, test_loader = _make_loaders(ScreenLayoutDataset(args.dataset),
                                              args.batch_size)
    model = LayoutAutoEncoder()
    # NOTE(review): original called model.cuda() unconditionally, which
    # crashes on CPU-only machines; guarded here. The trainers may still
    # assume CUDA tensors internally — confirm before relying on CPU runs.
    if torch.cuda.is_available():
        model.cuda()
    trainer = LayoutTrainer(model, train_loader, test_loader, args.rate)
    # Layout branch keeps the trainer's default checkpoint location.
    _run_training(trainer, args.epochs, "output/autoencoder")
elif args.type == 1:
    # Visual (pixel) autoencoder.
    train_loader, test_loader = _make_loaders(ScreenVisualLayoutDataset(args.dataset),
                                              args.batch_size)
    model = ImageAutoEncoder()
    if torch.cuda.is_available():
        model.cuda()
    trainer = ImageTrainer(model, train_loader, test_loader, args.rate)
    train_loss_data, test_loss_data = _run_training(
        trainer, args.epochs, "output/visual_encoder_fast",
        checkpoint_path="output/visual_encoder_fast")
    # Persist per-epoch (train, test) loss pairs for offline analysis
    # (only the visual branch did this originally; preserved as-is).
    with open("output/visual_encoder_fast.csv", 'w', newline='') as myfile:
        wr = csv.writer(myfile)
        for train_loss, test_loss in zip(train_loss_data, test_loss_data):
            wr.writerow([train_loss, test_loss])