-
Notifications
You must be signed in to change notification settings - Fork 172
/
Copy pathimage.py
149 lines (132 loc) · 4.84 KB
/
image.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
## Copyright 2018 Intel Corporation
## SPDX-License-Identifier: Apache-2.0
import os
import numpy as np
import torch
import OpenImageIO as oiio
from ssim import ssim, ms_ssim
from util import *
## -----------------------------------------------------------------------------
## Image operations
## -----------------------------------------------------------------------------
# Converts a NumPy image to a tensor
def image_to_tensor(image, batch=False):
    """Turn an HWC NumPy image into a CHW (optionally NCHW) tensor.

    Args:
        image: NumPy array with three axes, indexed (height, width, channel).
        batch: If true, a leading axis of size 1 is prepended to the result.

    Returns:
        A tensor built with ``torch.from_numpy`` from the transposed array.
    """
    # Axis 2 (channels) moves to the front: HWC -> CHW
    chw = torch.from_numpy(image.transpose((2, 0, 1)))
    # unsqueeze(0) adds the N axis: CHW -> NCHW
    return chw.unsqueeze(0) if batch else chw
# Converts a tensor to a NumPy image
def tensor_to_image(image):
    """Turn a CHW (or NCHW) tensor into an HWC NumPy image.

    Args:
        image: Tensor with 3 axes (CHW) or 4 axes (NCHW). For the 4-axis
            case the leading axis is removed with ``squeeze(0)``, which only
            has an effect when that axis has size 1.

    Returns:
        NumPy array with the axes reordered to (height, width, channel).
    """
    if image.ndim == 4:
        # Drop the N axis
        image = image.squeeze(0)
    array = image.cpu().numpy()
    # CHW -> HWC
    return array.transpose((1, 2, 0))
# Computes gradient for a tensor
def tensor_gradient(input):
    """Compute forward differences of a tensor along its last two axes.

    The last row and last column have no forward neighbour, so the result
    is one element smaller along each of the last two axes.

    Args:
        input: Tensor with at least three axes.

    Returns:
        The difference along the second-to-last axis followed by the
        difference along the last axis, concatenated on the third-to-last
        axis.
    """
    base = input[..., :-1, :-1]
    grad_rows = input[..., 1:, :-1] - base
    grad_cols = input[..., :-1, 1:] - base
    return torch.cat((grad_rows, grad_cols), dim=-3)
# Compares two image tensors using the specified error metric
def compare_images(a, b, metric='psnr'):
    """Compare two image tensors using the specified error metric.

    Args:
        a: First image tensor.
        b: Second image tensor, compatible with ``a`` for subtraction.
        metric: One of 'mse', 'psnr', 'ssim' or 'msssim'.

    Returns:
        'mse': the mean squared difference, as a tensor.
        'psnr': ``10 * log10(1 / mse)`` as a NumPy float (the peak value is
            taken to be 1); +inf when the images are identical.
        'ssim' / 'msssim': whatever ``ssim`` / ``ms_ssim`` return when called
            with ``data_range=1.``.

    Raises:
        ValueError: If ``metric`` is not one of the supported names.
    """
    if metric == 'mse':
        return ((a - b) ** 2).mean()
    elif metric == 'psnr':
        mse = ((a - b) ** 2).mean().item()
        if mse == 0:
            # Identical images: 1 / mse would raise ZeroDivisionError
            return np.float64(np.inf)
        return 10 * np.log10(1. / mse)
    elif metric == 'ssim':
        return ssim(a, b, data_range=1.)
    elif metric == 'msssim':
        return ms_ssim(a, b, data_range=1.)
    else:
        raise ValueError('invalid error metric')
## -----------------------------------------------------------------------------
## Image I/O
## -----------------------------------------------------------------------------
# Loads an image and returns the pixels as a float NumPy array and the number of loaded channels
def load_image(filename, num_channels=None):
    """Load an image file as a float NumPy array.

    At most the first three channels of the file are read. If
    ``num_channels`` is given, no more than that many channels are read, and
    if fewer were read the last loaded channel is repeated until the array
    has ``num_channels`` channels. NaN/inf values are replaced using
    ``np.nan_to_num``.

    Args:
        filename: Path of the image to open with OpenImageIO.
        num_channels: Optional number of channels the returned array should
            have.

    Returns:
        A tuple ``(image, load_num_channels)``: the pixel array and the
        number of channels actually read from the file (before padding).

    Raises:
        RuntimeError: If the file cannot be opened or its pixels cannot be
            read.
    """
    reader = oiio.ImageInput.open(filename)
    if not reader:
        raise RuntimeError('could not open image: "' + filename + '"')
    try:
        load_num_channels = min(reader.spec().nchannels, 3)
        if num_channels:
            load_num_channels = min(load_num_channels, num_channels)
        image = reader.read_image(subimage=0, miplevel=0, chbegin=0, chend=load_num_channels, format=oiio.FLOAT)
    finally:
        # Release the file handle even when reading fails
        reader.close()
    if image is None:
        raise RuntimeError('could not read image')
    if num_channels and image.shape[2] < num_channels:
        # Repeat the last channel to fill in the missing channels
        repeats = [1] * (image.shape[2] - 1) + [num_channels - image.shape[2] + 1]
        image = np.repeat(image, repeats, axis=2)
    image = np.nan_to_num(image)
    return image, load_num_channels
# Saves a float NumPy image
def save_image(filename, image, num_channels=None):
    """Save a float NumPy image to a file.

    Files with the extension 'pfm' or 'phm' are written by ``save_pfm`` /
    ``save_phm``; every other extension goes through OpenImageIO ('exr' as
    float with PIZ compression, anything else as 8-bit unsigned).

    Args:
        filename: Destination path; its extension selects the format.
        image: Array indexed (height, width, channel).
        num_channels: Optional number of channels to write. With 1, all
            channels are averaged into one; otherwise the extra trailing
            channels are dropped.

    Raises:
        RuntimeError: If the image has fewer channels than ``num_channels``,
            or if the output cannot be created, opened or written.
    """
    if num_channels and num_channels != image.shape[2]:
        if image.shape[2] < num_channels:
            raise RuntimeError('image to save has fewer channels than expected')
        elif num_channels == 1:
            # Compute the average of all channels
            image = np.mean(image, axis=2, keepdims=True)
        else:
            # Truncate to the specified number of channels
            image = image[:, :, :num_channels]

    ext = get_path_ext(filename).lower()
    if ext == 'pfm':
        save_pfm(filename, image)
    elif ext == 'phm':
        save_phm(filename, image)
    else:
        output = oiio.ImageOutput.create(filename)
        if not output:
            raise RuntimeError('could not create image: "' + filename + '"')
        pixel_format = oiio.FLOAT if ext == 'exr' else oiio.UINT8
        spec = oiio.ImageSpec(image.shape[1], image.shape[0], image.shape[2], pixel_format)
        if ext == 'exr':
            spec.attribute('compression', 'piz')
        elif ext == 'png':
            spec.attribute('png:compressionLevel', 3)
        if not output.open(filename, spec):
            raise RuntimeError('could not open image: "' + filename + '"')
        try:
            # FIXME: copy is needed for arrays owned by PyTorch for some reason
            if not output.write_image(image.copy()):
                raise RuntimeError('could not save image')
        finally:
            # Close the file even when writing fails
            output.close()
# Saves a float NumPy image in PFM format
def save_pfm(filename, image):
    """Save a float NumPy image in PFM format (32-bit float samples).

    The header identifier depends on the channel count: 'PF' for three or
    more channels (only the first three are written), 'Pf' for exactly one
    channel, and the non-standard 'P=' otherwise.

    Args:
        filename: Destination path.
        image: Array indexed (height, width, channel).
    """
    num_channels = image.shape[-1]
    if num_channels >= 3:
        magic = 'PF'
        data = image[..., 0:3]
    elif num_channels == 1:
        magic = 'Pf'
        data = image[..., 0]
    else:
        magic = 'P='  # non-standard 2-channel format
        data = image

    # Rows are written in reverse order. The negative scale in the header
    # declares little-endian samples, so the byte order is fixed explicitly
    # rather than relying on the machine's native order.
    data = np.flip(data, 0).astype(np.dtype('<f4'))
    header = '%s\n%d %d\n-1.0\n' % (magic, image.shape[1], image.shape[0])

    # Binary mode: text mode would apply newline translation and a
    # locale-dependent encoding to the header, which is followed by raw data.
    with open(filename, 'wb') as f:
        f.write(header.encode('ascii'))
        data.tofile(f)
# Saves a float NumPy image in PHM format
def save_phm(filename, image):
    """Save a float NumPy image in PHM format (16-bit float samples).

    The layout mirrors PFM but with half-precision samples. The header
    identifier depends on the channel count: 'PH' for three or more channels
    (only the first three are written), 'Ph' for exactly one channel, and
    the non-standard 'P:' otherwise.

    Args:
        filename: Destination path.
        image: Array indexed (height, width, channel).
    """
    num_channels = image.shape[-1]
    if num_channels >= 3:
        magic = 'PH'
        data = image[..., 0:3]
    elif num_channels == 1:
        magic = 'Ph'
        data = image[..., 0]
    else:
        magic = 'P:'  # non-standard 2-channel format
        data = image

    # Rows are written in reverse order. The negative scale in the header
    # declares little-endian samples, so the byte order is fixed explicitly
    # rather than relying on the machine's native order.
    data = np.flip(data, 0).astype(np.dtype('<f2'))
    header = '%s\n%d %d\n-1.0\n' % (magic, image.shape[1], image.shape[0])

    # Binary mode: text mode would apply newline translation and a
    # locale-dependent encoding to the header, which is followed by raw data.
    with open(filename, 'wb') as f:
        f.write(header.encode('ascii'))
        data.tofile(f)