-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathDMSPDeblur.py
127 lines (100 loc) · 4.92 KB
/
DMSPDeblur.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import time
import numpy as np
import scipy.signal as sig
import skimage
def computePSNR(img1, img2, pad_y, pad_x):
    """Compute the peak signal-to-noise ratio (PSNR) between two images.

    Both inputs are squeezed, clipped to [0, 255] and cast to uint8
    (fractional parts are truncated by the cast) before being compared, so
    the result reflects the difference between the 8-bit versions of the
    images.

    Input:
        img1: First image in range of [0, 255].
        img2: Second image in range of [0, 255].
        pad_y: Number of rows excluded at the top and at the bottom of the
            images. May be 0, in which case no rows are excluded.
        pad_x: Number of columns excluded at the left and at the right of the
            images. May be 0, in which case no columns are excluded.
    Output:
        PSNR in dB as a float; ``inf`` when the compared regions are
        identical.
    """

    def _cropped_uint8_as_float(img):
        # Quantize to 8 bit and drop the excluded border.
        arr = np.clip(np.squeeze(img), 0, 255.0)
        rows, cols = arr.shape[0], arr.shape[1]
        # Use explicit end indices: the "pad:-pad" form degenerates to an
        # empty slice when pad == 0, which would make the PSNR NaN.
        # The trailing Ellipsis keeps any remaining (channel) axes intact.
        arr = arr[pad_y:rows - pad_y, pad_x:cols - pad_x, ...]
        return arr.astype(np.uint8).astype(np.float32)

    imdiff = _cropped_uint8_as_float(img1) - _cropped_uint8_as_float(img2)
    # Accumulate in float64 so large images do not lose precision.
    mse = np.mean(np.square(imdiff), dtype=np.float64)
    if mse == 0:
        # Identical regions: PSNR is unbounded; avoid dividing by zero.
        return float('inf')
    return float(20.0 * np.log10(255.0 / np.sqrt(mse)))
def filter_image(image, kernel, mode='valid'):
    """Filter every channel of a color image with a 2-D kernel.

    The kernel is flipped along both axes and then convolved with each slice
    ``image[:, :, d]``; the per-channel results are stacked back along
    axis 2. ``mode`` is forwarded to ``scipy.signal.convolve2d``.
    """
    # Flip once up front; the flipped kernel is shared by all channels.
    flipped_kernel = np.fliplr(np.flipud(kernel))
    filtered = [
        sig.convolve2d(image[:, :, idx], flipped_kernel, mode=mode)
        for idx in range(image.shape[2])
    ]
    return np.stack(filtered, axis=2)
def convolve_image(image, kernel, mode='valid'):
    """Convolve every channel of a color image with a 2-D kernel.

    Each slice ``image[:, :, d]`` is convolved with ``kernel`` (unflipped)
    using ``scipy.signal.convolve2d`` in the given ``mode``; the per-channel
    results are stacked back along axis 2.
    """
    convolved = [
        sig.convolve2d(image[:, :, idx], kernel, mode=mode)
        for idx in range(image.shape[2])
    ]
    return np.stack(convolved, axis=2)
def DMSPDeblur(degraded, kernel, sigma_d, params):
    """ Implements stochastic gradient descent (SGD) Bayes risk minimization for image deblurring described in:
     "Deep Mean-Shift Priors for Image Restoration" (http://home.inf.unibe.ch/~bigdeli/DMSPrior.html)
     S. A. Bigdeli, M. Jin, P. Favaro, M. Zwicker, Advances in Neural Information Processing Systems (NIPS), 2017

    Input:
        degraded: Observed degraded RGB input image in range of [0, 255],
            indexed as (row, column, channel).
        kernel: 2-D blur kernel (internally flipped for convolution). Both
            dimensions must be odd so that the padded estimate, the blurred
            estimate and the observation line up.
        sigma_d: Noise standard deviation. (set to -1 for noise-blind deblurring)
        params: Mapping of parameters. It is only read, never modified.
            params['denoiser']: Object whose ``denoise(image)`` method returns
                the denoised image (presumably with the same shape as its
                input -- confirm against the denoiser in use).

        Optional parameters:
            params['gt']: Reference image. When present, the PSNR against it
                and per-iteration timing are printed.
            params['sigma_dae']: The standard deviation of the denoiser training noise. default: 11
            params['num_iter']: Specifies number of iterations. default: 10
            params['mu']: The momentum for SGD optimization. default: 0.9
            params['alpha']: The step length in SGD optimization. default: 0.1

    Outputs:
        res: Solution as a float32 array clipped to [0, 255]. It is larger
            than ``degraded`` by half the kernel size on every side.

    Raises:
        ValueError: If no denoiser is supplied or the kernel has an even
            dimension.
    """
    if 'denoiser' not in params:
        raise ValueError('Need a denoiser in params.denoiser!')

    # Read the options into locals instead of writing defaults back into the
    # caller's mapping.
    denoiser = params['denoiser']
    print_iter = 'gt' in params
    sigma_dae = params['sigma_dae'] if 'sigma_dae' in params else 11.0
    num_iter = params['num_iter'] if 'num_iter' in params else 10
    mu = params['mu'] if 'mu' in params else 0.9
    alpha = params['alpha'] if 'alpha' in params else 0.1

    kernel_h, kernel_w = kernel.shape[0], kernel.shape[1]
    if kernel_h % 2 == 0 or kernel_w % 2 == 0:
        # With an even size the 'valid' blur of the padded estimate is one
        # pixel larger than the observation, so the residual is ill-defined.
        raise ValueError('Blur kernel dimensions must be odd, got '
                         + str(kernel_h) + 'x' + str(kernel_w))
    pad_y = kernel_h // 2
    pad_x = kernel_w // 2

    # Start from the observation, edge-replicated so that a 'valid' blur of
    # the estimate has the same size as the observation.
    res = np.pad(degraded, pad_width=((pad_y, pad_y), (pad_x, pad_x), (0, 0)), mode='edge').astype(np.float32)
    step = np.zeros(res.shape)

    # Quantities that do not change between iterations.
    noise_blind = sigma_d < 0
    if noise_blind:
        sigma2 = 2 * sigma_dae * sigma_dae
        kernel_energy = np.sum(np.power(kernel, 2))
    else:
        # Same value as (1/sigma_d^2) / (1/sigma_d^2 + 1/sigma_dae^2), but
        # well-defined for sigma_d == 0 (the weight is then exactly 1).
        relative_weight = (sigma_dae * sigma_dae) / (sigma_dae * sigma_dae + sigma_d * sigma_d)

    if print_iter:
        psnr = computePSNR(params['gt'], res, pad_y, pad_x)
        print('Initialized with PSNR: ' + str(psnr))

    for iteration in range(num_iter):
        if print_iter:
            print('Running iteration: ' + str(iteration))
            t = time.time()

        # compute prior gradient: difference between the estimate and the
        # denoised version of a noise-perturbed copy of it
        noise = np.random.normal(0.0, sigma_dae, res.shape).astype(np.float32)
        rec = denoiser.denoise(res + noise)
        prior_grad = res - rec

        # compute data gradient: blur the estimate, compare with the
        # observation and map the residual back to the estimate's size
        map_conv = filter_image(res, kernel)
        data_err = map_conv - degraded
        data_grad = convolve_image(data_err, kernel, mode='full')

        if noise_blind:
            # Estimate the data-term weight from the current residual.
            lambda_ = (degraded.size) / (
                np.sum(np.power(data_err, 2)) + degraded.size * sigma2 * kernel_energy)
            relative_weight = lambda_ / (lambda_ + 1 / sigma_dae / sigma_dae)

        # sum the gradients
        grad_joint = data_grad * relative_weight + prior_grad * (1 - relative_weight)

        # update (gradient descent with momentum), then keep the estimate in
        # the valid intensity range
        step = mu * step - alpha * grad_joint
        res = res + step
        res = np.minimum(255.0, np.maximum(0, res)).astype(np.float32)

        if print_iter:
            psnr = computePSNR(params['gt'], res, pad_y, pad_x)
            print('PSNR is: ' + str(psnr) + ', iteration finished in ' + str(time.time() - t) + ' seconds')

    return res