-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathDAE_model.py
52 lines (43 loc) · 2.13 KB
/
DAE_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import tensorflow as tf
import numpy as np
import scipy.io as io
def get_variable(params, name):
    """Create a TensorFlow variable called *name* initialized from the array *params*.

    The initializer is a float32 constant holding *params*, and the variable
    is given the same shape as *params*.  Creation goes through
    ``tf.get_variable``, so the enclosing variable scope applies.
    """
    return tf.get_variable(
        name=name,
        initializer=tf.constant_initializer(params, dtype=tf.float32),
        shape=params.shape,
    )
def conv2d_basic(x, W, bias):
    """Convolve *x* with kernel *W* (stride 1, "SAME" padding) and add *bias*."""
    unit_strides = [1] * 4
    convolved = tf.nn.conv2d(x, W, strides=unit_strides, padding="SAME")
    return tf.nn.bias_add(convolved, bias)
def network(weights, image):
    """Build the DAE layer stack on top of *image* from MatConvNet-exported weights.

    Layers alternate by index: even indices are convolutions whose kernel and
    bias are read from ``weights['net'][0, i]['weights'][0, 0]``; odd indices
    are ReLUs.  Returns a dict mapping ``'layer<i>'`` to the output tensor of
    layer ``i``.
    """
    layers = weights['net']
    outputs = {}
    x = image
    for idx in range(layers.shape[1]):
        name = 'layer' + str(idx)
        if idx % 2:
            x = tf.nn.relu(x, name=name)
        else:
            params = layers[0, idx]['weights'][0, 0]
            w = np.float32(params[0, 0])
            b = np.float32(params[0, 1])
            # matconvnet: weights are [width, height, in_channels, out_channels]
            # tensorflow: weights are [height, width, in_channels, out_channels]
            # so the first two axes are swapped before creating the variable.
            w_var = get_variable(np.transpose(w, (1, 0, 2, 3)), name=name + "_w")
            b_var = get_variable(b.reshape(-1), name=name + "_b")
            x = conv2d_basic(x, w_var, b_var)
        outputs[name] = x
    return outputs
class denoiser(object):
    """ Implements DAE objects with neural net parameters and functions.

    The graph adds the output of the network's last layer to the network's
    input (a residual formulation).  Channels are reversed on the way in and
    reversed back on the way out; given the internal ``image_bgr`` naming, the
    caller-facing channel order is RGB.
    """

    def __init__(self, sess, model_path='DAE_sigma11.mat'):
        """Build the DAE graph and initialize its variables in *sess*.

        Args:
            sess: TensorFlow session used to initialize and later run the net.
            model_path: path of the ``.mat`` file holding the exported weights
                (read via ``scipy.io.loadmat``; its ``'net'`` entry is used).
                Defaults to ``'DAE_sigma11.mat'``, resolved against the current
                working directory as before.
        """
        self.sess = sess
        self.in_image = tf.placeholder(tf.float32, shape=[None, None, None, 3], name="input_image")
        # Reverse the channel axis so the network sees BGR.
        image_bgr = self.in_image[..., ::-1]
        weights = io.loadmat(model_path)
        last_layer = 'layer' + str(weights['net'].shape[1] - 1)
        with tf.variable_scope("dae", reuse=None) as dae_scope:
            dae_net = network(weights=weights, image=image_bgr)
        # Residual connection: the last layer's output is added to the input.
        output_bgr = image_bgr + dae_net[last_layer]
        self.output = output_bgr[..., ::-1]
        # Initialize only the variables created under the "dae" scope.  Running
        # tf.global_variables_initializer() here would also reset every other
        # variable in the graph, e.g. weights a caller has already trained.
        # The trailing "/" keeps the prefix match from catching "dae_*" scopes.
        dae_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=dae_scope.name + "/")
        self.sess.run(tf.variables_initializer(dae_vars))

    def denoise(self, noisy):
        """Denoise a single RGB image with values in the range 0 to 255.

        Args:
            noisy: array of shape (height, width, 3).
        Returns:
            The denoised image as a float32 array of shape (height, width, 3).
        """
        batch = np.expand_dims(noisy, axis=0)
        result = self.sess.run(self.output, feed_dict={self.in_image: batch})
        # Drop only the batch axis; np.squeeze would also collapse a height or
        # width of 1.
        return result[0]