-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtrain_srcnn.py
77 lines (61 loc) · 2.21 KB
/
train_srcnn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
# -*- coding: utf-8 -*-
"""
Created on Sun Mar 26 10:50:48 2017
@author: galad-loth
"""
#import numpy as npy
import logging
import os
import sys

import mxnet as mx

from symbols.srcnn_symbol import srcnn_symbol
from utils.dataio import get_sr_iter
from utils.evaluate_metric import psnr
# Route INFO-level log records (training progress reported through the
# `logging` module) to stdout exactly once.
# logging.basicConfig() is deliberately NOT called here: it would attach its
# own stderr StreamHandler to the root logger in addition to the stdout
# handler below, so every record would be printed twice.
root_logger = logging.getLogger()
root_logger.setLevel(logging.INFO)
stdout_handler = logging.StreamHandler(sys.stdout)
root_logger.addHandler(stdout_handler)
# Network definition and the Module that trains it on the first GPU.
# The positional arguments are forwarded unchanged to srcnn_symbol; their
# meaning is defined in symbols/srcnn_symbol.py.
net = srcnn_symbol(64, 32, 3)
gpu_ctx = mx.gpu()
mod = mx.mod.Module(net,
                    context=gpu_ctx,
                    data_names=['imgin'],        # input variable of the symbol
                    label_names=['loss_imghr'])  # label variable of the symbol
# Optimizer: AdaGrad with weight decay and element-wise gradient clipping.
# FactorScheduler(5000, 0.6) multiplies the learning rate by 0.6 every 5000
# updates.
# Alternative kept for reference: SGD with learning_rate=1e-6, momentum=0.9,
# wd=0.002 and FactorScheduler(9000, 0.9).
# NOTE(review): rescale_grad is not set, so the optimizer's default applies.
# MXNet's Module warns when a manually created optimizer's rescale_grad
# differs from 1/batch_size -- confirm the loss already normalizes per batch.
optimizer = mx.optimizer.create(
    'adagrad',
    learning_rate=0.0005,
    wd=0.005,
    clip_gradient=0.001,
    lr_scheduler=mx.lr_scheduler.FactorScheduler(5000, 0.6))

# Arguments whose names contain "conv0" or "conv1" get a 5x learning-rate
# multiplier; all other parameters keep the default multiplier.
lr_scale = {name: 5 for name in net.list_arguments()
            if "conv0" in name or "conv1" in name}
optimizer.set_lr_mult(lr_scale)

# Xavier initialisation: gaussian samples, fan-in scaling, magnitude 2
# (He-style, variance 2 / fan_in under MXNet's Xavier definition).
initializer = mx.init.Xavier(rnd_type='gaussian',
                             factor_type="in",
                             magnitude=2)
# Checkpoints are saved every 100 epochs under ./checkpoint with the prefix
# "srcnn".  os.path.join keeps the prefix correct on every OS (a literal
# backslash is only a separator on Windows).  The directory is created up
# front so the first save (epoch 100) cannot fail on a missing folder.
checkpoint_dir = "checkpoint"
model_prefix = os.path.join(checkpoint_dir, "srcnn")
if not os.path.isdir(checkpoint_dir):
    os.makedirs(checkpoint_dir)
checkpoint = mx.callback.do_checkpoint(model_prefix, period=100)

# Training data.  Dataset locations are machine-specific absolute paths;
# adjust them for your environment.
datadir = "E:\\DevProj\\Datasets\\SuperResolution\\SRCNN_Train"
batch_size = 100
data_params = {"batch_size": batch_size, "crop_size": 33, "scale_factor": 2,
               "is_train": True, "num_train_img": 90, "num_val_img": 0,
               "img_type": [".jpg", ".bmp"]}
# The second returned value is discarded; validation uses Set14 below.
train_iter, _ = get_sr_iter(datadir, data_params)

# Validation data (Set14).  Work on a copy instead of mutating data_params,
# which the training iterator may still hold a reference to.
datadir1 = "E:\\DevProj\\Datasets\\SuperResolution\\Set14"
val_params = dict(data_params)
val_params["is_train"] = False
# NOTE(review): unlike the training call, this result is not unpacked --
# confirm get_sr_iter returns a single iterator when is_train is False.
val_iter = get_sr_iter(datadir1, val_params)

mod.fit(train_iter,
        num_epoch=1000,
        eval_data=val_iter,
        eval_metric=psnr,
        optimizer=optimizer,
        initializer=initializer,
        # report throughput every 2000 batches
        batch_end_callback=mx.callback.Speedometer(batch_size, 2000),
        epoch_end_callback=checkpoint)