forked from KindofCrazy/Music-Genre-Classification
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtools.py
147 lines (128 loc) · 5.22 KB
/
tools.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import torch
from torch import nn
import torchvision
from torch.utils import data
from torchvision import transforms
import visdom
import my_models
def weight_init(m):
    """Xavier-normal initialise the weight of a plain ``nn.Linear`` or ``nn.Conv2d``.

    Meant to be used as ``net.apply(weight_init)``. The test is on the exact
    module type (not ``isinstance``), so subclasses of these layers are left
    alone; biases and all other module types are not touched either.
    """
    if type(m) in (nn.Linear, nn.Conv2d):
        nn.init.xavier_normal_(m.weight)
class accumulator:  # reusable helper: running totals of training metrics (loss, accuracy, ...) for plotting
    """Hold ``n`` float running totals that are summed element-wise."""

    def __init__(self, n):
        self.data = [0.0 for _ in range(n)]

    def add(self, args):
        """Add the values in ``args`` (each converted with ``float``) to the totals.

        Values are paired with ``zip``, so if ``args`` is shorter than the
        current totals, the stored list is cut down to ``len(args)`` entries.
        """
        totals = []
        for current, extra in zip(self.data, args):
            totals.append(current + float(extra))
        self.data = totals

    def reset(self):
        """Set every total back to 0.0, keeping the same number of slots."""
        self.data = [0.0] * len(self.data)

    def __getitem__(self, idx):
        # Delegates to the list, so negative indices and slices work too.
        return self.data[idx]
class visualize(object):
    """Live visdom line plot of train loss, train accuracy and test accuracy.

    The window is created once in ``__init__``; ``paint`` appends one point per
    curve to that same window (same ``win`` and same ``env``).
    """

    def __init__(self, env='module_1', win='classifier'):
        """Connect to visdom and create the plot window.

        Args:
            env: visdom environment holding the window.
            win: window id; also used as the plot title.
        """
        self.env = env
        self.win = win
        self.legend = ['train_loss', 'train_accuracy', 'test_accuracy']
        self.vis = visdom.Visdom()
        self.vis.line(
            X=[0.],
            Y=[[0., 0., 0.]],
            win=self.win,
            env=self.env,
            opts=dict(title=self.win, legend=self.legend))

    def paint(self, train_loss, test_accuracy, train_accuracy, epochs):
        """Append the metrics for one epoch at x = ``epochs``.

        The parameter order is kept for existing callers. The Y row is built in
        legend order so each value lands on the curve whose label names it.
        Previously the two accuracies were swapped relative to the legend.
        """
        self.vis.line(
            X=[epochs],
            Y=[[train_loss, train_accuracy, test_accuracy]],
            win=self.win,
            # Without env, visdom writes to its default env, not the one
            # the window was created in.
            env=self.env,
            update='append',
            opts=dict(legend=self.legend))
def data_iter(train_data, test_data, batch_size=50, num_workers=4, prefetch_factor=2):
    """Wrap a training and a test dataset in DataLoaders.

    Args:
        train_data, test_data: map-style datasets accepted by
            ``torch.utils.data.DataLoader`` (e.g. a ``TensorDataset``).
        batch_size: mini-batch size for both loaders.
        num_workers: worker processes per loader; 0 loads in the main process.
        prefetch_factor: batches each worker loads ahead. It is only forwarded
            when ``num_workers > 0``; recent PyTorch versions raise if it is
            given for single-process loading.

    Returns:
        ``(train_loader, test_loader)``; only the training loader shuffles.
    """
    loader_kwargs = {"batch_size": batch_size, "num_workers": num_workers}
    if num_workers > 0:
        loader_kwargs["prefetch_factor"] = prefetch_factor
    return (
        data.DataLoader(train_data, shuffle=True, **loader_kwargs),
        data.DataLoader(test_data, shuffle=False, **loader_kwargs),
    )
def loss():
    """Build the training criterion: a default-configured cross-entropy loss."""
    criterion = nn.CrossEntropyLoss()
    return criterion
def optimize(model, lr=0.001, weight_decay=0, momentum=0.9):
    """Build the SGD optimizer used for training.

    Args:
        model: module whose parameters are optimised.
        lr: learning rate.
        weight_decay: L2 penalty forwarded to SGD. It used to be accepted but
            ignored; the default of 0 keeps the old behaviour.
        momentum: SGD momentum factor.

    Returns:
        A ``torch.optim.SGD`` instance over ``model.parameters()``.
    """
    return torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum,
                           weight_decay=weight_decay)
def train_epoch(X, y, net, loss, optimizer, device):
    """Run one optimisation step on a single mini-batch ``(X, y)``.

    Despite the name, this processes one batch, not a whole epoch.

    Returns:
        ``(loss_times_batch_size, correct_count, batch_size)``.
        - ``loss_times_batch_size``: the criterion's value multiplied by the
          number of samples, computed with autograd disabled. This is the
          batch loss sum when the criterion averages over the batch.
        - ``correct_count``: the result of ``accuracy``.
    """
    X = X.to(device)
    y = y.to(device)
    optimizer.zero_grad()
    predictions = net(X)
    batch_loss = loss(predictions, y)
    batch_loss.backward()
    optimizer.step()
    n = X.shape[0]
    with torch.no_grad():
        return batch_loss * n, accuracy(predictions, y), n
def train(net, train_iter, test_iter, num_epochs, loss, optimizer, device,
          step_size=35, gamma=0.65):
    """Train ``net`` and plot per-epoch metrics to visdom.

    The visdom plot is created inside this function. Each epoch:
    1. switch the net to training mode;
    2. run ``train_epoch`` on every batch of ``train_iter``;
    3. step a StepLR schedule;
    4. evaluate test accuracy on ``test_iter`` and plot the results.

    Args:
        net: model to train; moved to ``device`` first.
        train_iter, test_iter: iterables yielding ``(X, y)`` batches.
        num_epochs: number of passes over ``train_iter``.
        loss: criterion called as ``loss(y_hat, y)``.
        optimizer: optimizer over ``net``'s parameters.
        device: device that the model and batches are moved to.
        step_size, gamma: StepLR settings; the learning rate is multiplied by
            ``gamma`` every ``step_size`` epochs.
    """
    net.to(device)
    visualizer = visualize()
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer=optimizer, step_size=step_size, gamma=gamma)
    for epoch in range(num_epochs):
        # accuracy_test() calls net.eval(). Without switching back here, every
        # epoch after the first would train with eval-mode behaviour in layers
        # such as dropout / batch-norm, if the model has any.
        net.train()
        metric = accumulator(3)  # loss sum, correct predictions, sample count
        for X, y in train_iter:
            metric.add(train_epoch(X, y, net, loss, optimizer, device))
        scheduler.step()
        train_l = metric[0] / metric[2]
        train_acc = metric[1] / metric[2]
        test_acc = accuracy_test(net, test_iter)
        visualizer.paint(train_l, test_acc, train_acc, epoch)
def accuracy(y_hat, y):
    """Return how many predictions in ``y_hat`` match the labels ``y``, as a float.

    When ``y_hat`` has more than one column, it is treated as per-class scores
    and reduced with ``argmax`` over dim 1. Otherwise it is compared to ``y``
    directly, after casting to ``y``'s dtype.
    """
    multi_column = y_hat.dim() > 1 and y_hat.size(1) > 1
    predicted = y_hat.argmax(dim=1) if multi_column else y_hat
    hits = predicted.to(y.dtype) == y
    return float(hits.to(y.dtype).sum())
def accuracy_test(net, data_iter, device=None):
    """Compute the accuracy of ``net`` over all batches of ``data_iter`` (GPU-capable).

    If ``net`` is an ``nn.Module``, it is evaluated in eval mode. When it was in
    training mode before the call, training mode is restored afterwards, even
    if evaluation raises, so callers in the middle of training are not left
    with an eval-mode model. When ``device`` is None, it defaults to the device
    of the model's first parameter.

    Returns:
        Correct predictions divided by the total number of labels, as a float.

    Raises:
        ZeroDivisionError: if ``data_iter`` yields no labels.
    """
    was_training = False
    if isinstance(net, nn.Module):
        was_training = net.training
        net.eval()  # evaluation mode
        if device is None:
            device = next(iter(net.parameters())).device
    metric = accumulator(2)  # correct predictions, total predictions
    try:
        with torch.no_grad():
            for X, y in data_iter:
                if isinstance(X, list):
                    X = [x.to(device) for x in X]
                else:
                    X = X.to(device)
                y = y.to(device)
                metric.add((accuracy(net(X), y), y.numel()))
    finally:
        if was_training:
            net.train()
    return metric[0] / metric[1]
def save_model(net, path=None):
    """Save the whole ``net`` object (not just its state_dict) with ``torch.save``.

    Args:
        net: model to save.
        path: destination ``.pth`` file. Defaults to ``../models/alpha.pth``,
            built with ``os.path.join``. The old default used Windows-only
            backslash separators and was a plain file name in the current
            directory on other systems.
    """
    import os

    if path is None:
        path = os.path.join("..", "models", "alpha.pth")
    torch.save(net, path)
def load_model(path=None, device='cuda'):
    """Load a model that was saved whole with ``torch.save`` and prepare it for inference.

    Security: this unpickles arbitrary Python objects. Only load files from a
    trusted source.

    Args:
        path: ``.pth`` file to read. Defaults to ``../models/beta 96.1%.pth``,
            built with ``os.path.join`` so the separator is portable.
        device: where the tensors are mapped on load and the model is moved
            (``'cuda'`` by default, as before).

    Returns:
        The loaded model, on ``device`` and in eval mode.
    """
    import inspect
    import os

    if path is None:
        path = os.path.join("..", "models", "beta 96.1%.pth")
    load_kwargs = {"map_location": device}
    # PyTorch 2.6 changed torch.load to weights_only=True by default, which
    # refuses whole pickled modules. Opt out explicitly where the flag exists.
    if "weights_only" in inspect.signature(torch.load).parameters:
        load_kwargs["weights_only"] = False
    model = torch.load(path, **load_kwargs)
    model.to(device)
    model.eval()
    return model
# Everything below is test / demo code.


def load_data_fashion_mnist(batch_size, resize=None, root="../data",
                            download=False, num_workers=1):  #@save
    """Load Fashion-MNIST into train and test DataLoaders.

    Args:
        batch_size: mini-batch size for both loaders.
        resize: if given, images go through ``transforms.Resize(resize)``
            before ``ToTensor``.
        root: dataset directory passed to torchvision.
        download: forwarded to torchvision. The default stays False, so the
            dataset must already exist under ``root``.
        num_workers: worker processes for each DataLoader.

    Returns:
        ``(train_loader, test_loader)``; only the training loader shuffles.
    """
    steps = [transforms.ToTensor()]
    if resize:
        steps.insert(0, transforms.Resize(resize))
    trans = transforms.Compose(steps)
    mnist_train = torchvision.datasets.FashionMNIST(
        root=root, train=True, transform=trans, download=download)
    mnist_test = torchvision.datasets.FashionMNIST(
        root=root, train=False, transform=trans, download=download)
    return (data.DataLoader(mnist_train, batch_size, shuffle=True,
                            num_workers=num_workers),
            data.DataLoader(mnist_test, batch_size, shuffle=False,
                            num_workers=num_workers))
if __name__ == '__main__':
    # Demo run: train the project's VGG variant on Fashion-MNIST resized to 448x448.
    lr, num_epochs, batch_size = 0.001, 10, 50
    device = 'cuda'
    train_iter, test_iter = load_data_fashion_mnist(batch_size, resize=448)
    net = my_models.vgg(ratio=4, size_x=448, size_y=448)
    net.apply(weight_init)
    # Move the model before building its optimizer, as the torch.optim docs recommend.
    net.to(device)
    train(net, train_iter, test_iter, num_epochs, loss(), optimize(net, lr), device)
    # Shape-debugging snippet kept from development. It used to be a dead
    # string literal with '/t' where a tab was meant. It feeds a 96x96 input,
    # while the model above is built with size_x=size_y=448; adjust before
    # re-enabling.
    # X = torch.randn(size=(1, 1, 96, 96))
    # for blk in net:
    #     X = blk(X)
    #     print(blk.__class__.__name__, 'output shape:\t', X.shape)