-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathutils.py
91 lines (74 loc) · 2.64 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import os, gzip, torch
import torch.nn as nn
import numpy as np
import scipy.misc
import imageio
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
import torch.nn.functional as F
from torch.autograd import Variable
def print_network(net):
    """Print a network followed by its total parameter count.

    Args:
        net: object exposing ``parameters()``, whose items each provide
            ``numel()`` (e.g. a ``torch.nn.Module``).
    """
    # Add up the element counts of every parameter tensor.
    total = sum(p.numel() for p in net.parameters())
    print(net)
    print('Total number of parameters: %d' % total)
def save_images(images, size, image_path):
    """Save a batch of images to ``image_path`` as one tiled picture.

    Thin alias for ``imsave``: the three arguments are forwarded unchanged
    and its return value is passed back to the caller.
    """
    outcome = imsave(images, size, image_path)
    return outcome
def imsave(images, size, path):
    """Tile a batch of images with ``merge`` and write the grid to ``path``.

    ``scipy.misc.imsave`` was removed in SciPy 1.2, so the file is written
    with ``imageio`` (already imported by this module) instead.

    Args:
        images: image batch accepted by ``merge``.
        size: ``(rows, cols)`` grid layout, forwarded to ``merge``.
        path: destination file name; the extension selects the format.

    Returns:
        Whatever ``imageio.imwrite`` returns.
    """
    image = np.squeeze(merge(images, size))
    if image.dtype != np.uint8:
        # The old scipy.misc.imsave min-max rescaled non-uint8 data to
        # 0..255 before writing; do the same explicitly so saved images
        # look the same as before.
        low, high = image.min(), image.max()
        span = (high - low) or 1  # constant image: avoid dividing by zero
        scaled = (image - low) * (255.0 / span)
        image = (scaled.clip(0, 255) + 0.5).astype(np.uint8)
    return imageio.imwrite(path, image)
def merge(images, size):
    """Tile a batch of equally sized images into a single grid image.

    Images are laid out row by row: image ``idx`` is placed in grid row
    ``idx // size[1]`` and grid column ``idx % size[1]``.

    Args:
        images: batch of shape (N, H, W, C) with C in {1, 3, 4}, or
            (N, H, W) for grayscale images without a channel axis.
        size: ``(rows, cols)`` of the grid.

    Returns:
        A float64 array of shape (H * rows, W * cols, C) for 3- or
        4-channel input, or (H * rows, W * cols) for grayscale input.
        Grid cells with no image stay zero.

    Raises:
        ValueError: if ``images`` is neither 3-D nor 4-D, or its channel
            count is not 1, 3 or 4.
    """
    ndim = len(images.shape)
    if ndim == 4 and images.shape[3] in (3, 4):
        channels = images.shape[3]
    elif ndim == 3 or (ndim == 4 and images.shape[3] == 1):
        channels = None  # grayscale: the canvas is 2-D
    else:
        raise ValueError('in merge(images,size) images parameter '
                         'must have dimensions: HxW or HxWx3 or HxWx4')

    h, w = images.shape[1], images.shape[2]
    rows, cols = size[0], size[1]
    if channels is None:
        img = np.zeros((h * rows, w * cols))
    else:
        img = np.zeros((h * rows, w * cols, channels))

    # A 4-D single-channel batch carries a trailing axis that must be dropped.
    drop_channel_axis = channels is None and ndim == 4
    for idx, image in enumerate(images):
        j, i = divmod(idx, cols)  # j: grid row, i: grid column
        tile = image[:, :, 0] if drop_channel_axis else image
        img[j * h:j * h + h, i * w:i * w + w] = tile
    return img
def generate_animation(path, num):
    """Assemble per-epoch PNG snapshots into an animated GIF.

    Reads ``<path>_epoch001.png`` through ``<path>_epoch<num>.png`` and
    writes them, in order, to ``<path>_generate_animation.gif`` at 5 fps.

    Args:
        path: common file-name prefix of the snapshots and of the output.
        num: number of epochs (frames) to include.
    """
    frames = [
        imageio.imread(path + '_epoch%03d' % epoch + '.png')
        for epoch in range(1, num + 1)
    ]
    gif_name = path + '_generate_animation.gif'
    imageio.mimsave(gif_name, frames, fps=5)
def loss_plot(hist, path='Train_hist.png', model_name=''):
    """Plot the recorded D/G/C/E loss curves and save the figure.

    Args:
        hist: mapping with the keys ``'D_loss'``, ``'G_loss'``,
            ``'C_loss'`` and ``'E_loss'``, each holding a sequence of loss
            values. The x axis spans ``len(hist['D_loss'])`` points.
        path: directory component of the output file; it is joined with
            ``model_name + '_loss.png'``.
        model_name: prefix of the output file name.
    """
    steps = range(len(hist['D_loss']))
    # Look up every series before drawing, so a missing key fails up front.
    curves = [(key, hist[key]) for key in ('D_loss', 'G_loss', 'C_loss', 'E_loss')]
    for key, values in curves:
        plt.plot(steps, values, label=key)

    plt.xlabel('Iter')
    plt.ylabel('Loss')

    plt.legend(loc=4)
    plt.grid(True)
    plt.tight_layout()

    target = os.path.join(path, model_name + '_loss.png')
    plt.savefig(target)
    plt.close()
def initialize_weights(net):
    """Initialise conv, transposed-conv and linear layers of ``net`` in place.

    For every ``nn.Conv2d``, ``nn.ConvTranspose2d`` and ``nn.Linear``
    sub-module, the weights are drawn from a normal distribution with mean 0
    and standard deviation 0.02, and the bias (when the layer has one) is
    set to zero. Other module types are left untouched.

    Args:
        net: module whose ``modules()`` are walked.
    """
    for m in net.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
            m.weight.data.normal_(0, 0.02)
            # Layers created with bias=False expose ``bias`` as None.
            if m.bias is not None:
                m.bias.data.zero_()