-
Notifications
You must be signed in to change notification settings - Fork 537
/
Copy pathmain_train.py
180 lines (146 loc) · 7.15 KB
/
main_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
# -*- coding: utf-8 -*-
# PyTorch 0.4.1, https://pytorch.org/docs/stable/index.html
# =============================================================================
# @article{zhang2017beyond,
# title={Beyond a {Gaussian} denoiser: Residual learning of deep {CNN} for image denoising},
# author={Zhang, Kai and Zuo, Wangmeng and Chen, Yunjin and Meng, Deyu and Zhang, Lei},
# journal={IEEE Transactions on Image Processing},
# year={2017},
# volume={26},
# number={7},
# pages={3142-3155},
# }
# by Kai Zhang (08/2018)
# https://github.com/cszn
# modified on the code from https://github.com/SaoYan/DnCNN-PyTorch
# =============================================================================
# run this to train the model
# =============================================================================
# For batch normalization layer, momentum should be a value from [0.1, 1] rather than the default 0.1.
# The Gaussian noise output helps to stabilize the batch normalization, thus a large momentum (e.g., 0.95) is preferred.
# =============================================================================
import argparse
import re
import os, glob, datetime, time
import numpy as np
import torch
import torch.nn as nn
from torch.nn.modules.loss import _Loss
import torch.nn.init as init
from torch.utils.data import DataLoader
import torch.optim as optim
from torch.optim.lr_scheduler import MultiStepLR
import data_generator as dg
from data_generator import DenoisingDataset
# Params
# Command-line options for a training run; every option has a default, so
# the script can be launched without arguments.
parser = argparse.ArgumentParser(description='PyTorch DnCNN')
parser.add_argument('--model', default='DnCNN', type=str, help='choose a type of model')
parser.add_argument('--batch_size', default=128, type=int, help='batch size')
parser.add_argument('--train_data', default='data/Train400', type=str, help='path of train data')
parser.add_argument('--sigma', default=25, type=int, help='noise level')
parser.add_argument('--epoch', default=180, type=int, help='number of train epoches')
parser.add_argument('--lr', default=1e-3, type=float, help='initial learning rate for Adam')
args = parser.parse_args()

# Module-level settings read by the training code further down.
batch_size = args.batch_size
cuda = torch.cuda.is_available()  # use the GPU whenever one is visible
n_epoch = args.epoch
sigma = args.sigma

# Checkpoints are written to models/<model>_sigma<sigma>/.
save_dir = os.path.join('models', args.model+'_' + 'sigma' + str(sigma))
# makedirs (unlike mkdir) also creates the parent 'models' directory, and
# exist_ok avoids the check-then-create race of an explicit exists() test.
os.makedirs(save_dir, exist_ok=True)
class DnCNN(nn.Module):
    """Residual denoising CNN (DnCNN).

    The network is a plain stack of convolutions: a ``Conv + ReLU`` head,
    ``depth - 2`` middle blocks of ``Conv (+ BatchNorm) + ReLU``, and a final
    ``Conv`` that maps back to ``image_channels``.  ``forward`` subtracts the
    stack's output from the input, so the stack itself models the residual.

    Args:
        depth: total number of convolution layers (head + middle + tail).
        n_channels: number of feature maps in the hidden layers.
        image_channels: number of channels of the input and of the output.
        use_bnorm: when True (default) each middle block contains a
            ``BatchNorm2d`` layer; when False the layer is left out.
        kernel_size: odd convolution kernel size.  Padding is derived as
            ``kernel_size // 2`` so the spatial size is preserved.

    Raises:
        ValueError: if ``kernel_size`` is even; "same" padding is then
            impossible and the residual subtraction in ``forward`` would fail.
    """

    def __init__(self, depth=17, n_channels=64, image_channels=1, use_bnorm=True, kernel_size=3):
        super(DnCNN, self).__init__()
        if kernel_size % 2 == 0:
            raise ValueError('kernel_size must be odd, got %d' % kernel_size)
        # "Same" padding keeps H x W unchanged, which forward() relies on.
        padding = kernel_size // 2

        layers = []
        # Head: the only convolution with a bias term.
        layers.append(nn.Conv2d(in_channels=image_channels, out_channels=n_channels, kernel_size=kernel_size, padding=padding, bias=True))
        layers.append(nn.ReLU(inplace=True))
        # Middle blocks.
        for _ in range(depth-2):
            layers.append(nn.Conv2d(in_channels=n_channels, out_channels=n_channels, kernel_size=kernel_size, padding=padding, bias=False))
            if use_bnorm:
                layers.append(nn.BatchNorm2d(n_channels, eps=0.0001, momentum=0.95))
            layers.append(nn.ReLU(inplace=True))
        # Tail: project back to the image channel count.
        layers.append(nn.Conv2d(in_channels=n_channels, out_channels=image_channels, kernel_size=kernel_size, padding=padding, bias=False))
        self.dncnn = nn.Sequential(*layers)
        self._initialize_weights()

    def forward(self, x):
        """Return ``x`` minus the output of the convolution stack."""
        y = x
        out = self.dncnn(x)
        return y-out

    def _initialize_weights(self):
        """Orthogonal init for conv weights, zero biases, unit BatchNorm scale."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.orthogonal_(m.weight)
                print('init weight')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
class sum_squared_error(_Loss):  # PyTorch 0.4.1
    """Half of the summed squared error between ``input`` and ``target``.

    Equivalent to ``1/2 * nn.MSELoss(reduction='sum')``; its gradient with
    respect to ``input`` is ``input - target``.  The constructor arguments
    are only forwarded to the ``_Loss`` base class: ``forward`` always uses
    the ``'sum'`` reduction.
    """

    def __init__(self, size_average=None, reduce=None, reduction='sum'):
        super().__init__(size_average, reduce, reduction)

    def forward(self, input, target):
        """Return ``sum((input - target) ** 2) / 2`` as a scalar tensor."""
        squared_error_sum = torch.nn.functional.mse_loss(input, target, reduction='sum')
        return squared_error_sum / 2
def findLastCheckpoint(save_dir):
    """Return the highest epoch number among the checkpoints in *save_dir*.

    Checkpoints are files named ``model_<epoch>.pth`` where ``<epoch>`` is a
    run of decimal digits (e.g. ``model_042.pth``).  Files that match the
    ``model_*.pth`` glob but carry a non-numeric tag (e.g. ``model_best.pth``)
    are ignored instead of aborting the scan.

    Args:
        save_dir: directory to scan.

    Returns:
        The largest epoch number found, or 0 when there is no checkpoint.
    """
    # Match on the bare file name so directory components of the path can
    # never take part in the match; the dot is escaped and the tag is digits.
    name_pattern = re.compile(r'model_(\d+)\.pth')
    epochs_exist = []
    for path in glob.glob(os.path.join(save_dir, 'model_*.pth')):
        match = name_pattern.fullmatch(os.path.basename(path))
        if match:
            epochs_exist.append(int(match.group(1)))
    return max(epochs_exist, default=0)
def log(*args, **kwargs):
    """Print a message prefixed with the current time.

    The prefix has the form ``YYYY-MM-DD HH:MM:SS:``; all positional and
    keyword arguments are forwarded unchanged to ``print``.
    """
    stamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S:")
    print(stamp, *args, **kwargs)
if __name__ == '__main__':
    # Training entry point: build (or resume) the model, then run one pass
    # over freshly generated patches per epoch and checkpoint after each one.
    print('===> Building model')
    # A single device object lets the same code path run on GPU and on CPU.
    device = torch.device('cuda' if cuda else 'cpu')
    model = DnCNN()

    initial_epoch = findLastCheckpoint(save_dir=save_dir)  # load the last model in matconvnet style
    if initial_epoch > 0:
        print('resuming by loading epoch %03d' % initial_epoch)
        # Checkpoints hold the whole pickled module (see torch.save below).
        # SECURITY: torch.load unpickles arbitrary objects; only resume from
        # checkpoint files you trust.
        # NOTE(review): recent PyTorch releases default torch.load to
        # weights_only=True, which rejects full-module pickles; confirm
        # against the installed version.
        # map_location lets a checkpoint written on a GPU be resumed on CPU.
        model = torch.load(os.path.join(save_dir, 'model_%03d.pth' % initial_epoch), map_location=device)
    model.train()
    criterion = sum_squared_error()
    model = model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    scheduler = MultiStepLR(optimizer, milestones=[30, 60, 90], gamma=0.2)  # learning rates

    for epoch in range(initial_epoch, n_epoch):
        # Jump to the learning rate of this epoch (also correct on resume).
        # NOTE(review): passing the epoch to step() is deprecated in newer
        # PyTorch releases; kept because the resume logic depends on it.
        scheduler.step(epoch)

        # The patch set is regenerated every epoch.
        xs = dg.datagenerator(data_dir=args.train_data)
        xs = xs.astype('float32')/255.0
        xs = torch.from_numpy(xs.transpose((0, 3, 1, 2)))  # tensor of the clean patches, NXCXHXW
        DDataset = DenoisingDataset(xs, sigma)
        DLoader = DataLoader(dataset=DDataset, num_workers=4, drop_last=True, batch_size=batch_size, shuffle=True)

        epoch_loss = 0
        n_batches = 0
        start_time = time.time()

        for n_count, batch_yx in enumerate(DLoader):
            optimizer.zero_grad()
            # Element 0 is fed to the network and element 1 is the target
            # (presumably noisy / clean patches; confirm in DenoisingDataset).
            # Moving both unconditionally fixes the CPU path, where the
            # original left these names undefined.
            batch_y = batch_yx[0].to(device)
            batch_x = batch_yx[1].to(device)
            loss = criterion(model(batch_y), batch_x)
            batch_loss = loss.item()
            epoch_loss += batch_loss
            n_batches += 1
            loss.backward()
            optimizer.step()
            if n_count % 10 == 0:
                print('%4d %4d / %4d loss = %2.4f' % (epoch+1, n_count, xs.size(0)//batch_size, batch_loss/batch_size))
        elapsed_time = time.time() - start_time

        if n_batches == 0:
            raise RuntimeError('no training batch was produced; the dataset is smaller than batch_size')
        # Average over the number of batches actually seen, not the last index.
        mean_loss = epoch_loss/n_batches

        log('epcoh = %4d , loss = %4.4f , time = %4.2f s' % (epoch+1, mean_loss, elapsed_time))
        # Overwritten every epoch: the file only holds the latest epoch's numbers.
        np.savetxt('train_result.txt', np.hstack((epoch+1, mean_loss, elapsed_time)), fmt='%2.4f')
        # The whole module is saved (not just its state_dict), matching the load above.
        torch.save(model, os.path.join(save_dir, 'model_%03d.pth' % (epoch+1)))