# test.py — saliency inference script (89 lines, 77 loc, 3.69 KB)
import torch
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
import pdb, os, argparse
from scipy import misc
from model.ResNet_models import Generator, Descriptor
from data import test_dataset
os.environ["CUDA_VISIBLE_DEVICES"] = '1'
import cv2
# Command-line configuration for test-time inference.
parser = argparse.ArgumentParser()
parser.add_argument('--testsize', type=int, default=352, help='testing size')
parser.add_argument('--latent_dim', type=int, default=6, help='latent dim')
parser.add_argument('--channel_reduced_gen', type=int, default=32, help='reduced channel dimension for generator')
parser.add_argument('--channel_reduced_des', type=int, default=64, help='reduced channel dimension for descriptor')
parser.add_argument('--energy_form', default='identity', help='tanh | sigmoid | identity | softplus')
# NOTE(review): the next two options use a single leading dash, unlike the
# rest — argparse still accepts them (as --sigma_des is not defined), but the
# CLI is inconsistent; confirm before normalizing, as it would change the CLI.
parser.add_argument('-sigma_des', type=float, default=0.1,help='sigma of EBM langevin')
parser.add_argument('--langevin_step_num_des', type=int, default=3, help='number of langevin steps for ebm')
parser.add_argument('-langevin_step_size_des', type=float, default=0.001,help='step size of EBM langevin')
parser.add_argument('--z_sample_iterations', type=int, default=10, help='number of iterations for sampling z from latent space')
opt = parser.parse_args()
# Test images and (unused here) ground-truth directories.
dataset_path = './test/img/'
gt_path = './test/gt/'
# Build generator/descriptor, restore epoch-30 checkpoints, and switch both
# to eval mode on the GPU. torch.load has no map_location, so a CUDA-capable
# machine is required (CUDA_VISIBLE_DEVICES is pinned above).
generator = Generator(channel=opt.channel_reduced_gen, latent_dim=opt.latent_dim)
descriptor = Descriptor(channel=opt.channel_reduced_des)
generator.load_state_dict(torch.load('./models/Resnet/Model_30_gen.pth'))
descriptor.load_state_dict(torch.load('./models/Resnet/Model_30_des.pth'))
generator.cuda()
generator.eval()
descriptor.cuda()
descriptor.eval()
test_datasets = ['ECSSD','DUT','DUTS_Test','HKU-IS', 'PASCAL', 'SOD']
def compute_energy(disc_score, energy_form=None):
    """Map a descriptor score to an energy value.

    Args:
        disc_score: tensor of raw descriptor scores; squeezed, then negated
            before the chosen transform is applied.
        energy_form: one of 'tanh' | 'sigmoid' | 'identity' | 'softplus'.
            When None (the default, preserving the original call signature),
            falls back to the command-line ``opt.energy_form``.

    Returns:
        Tensor of energies.

    Raises:
        ValueError: if the energy form is not a supported name (the original
            silently fell through and raised NameError on ``energy``).
    """
    form = opt.energy_form if energy_form is None else energy_form
    score = -disc_score.squeeze()
    if form == 'tanh':
        return torch.tanh(score)
    if form == 'sigmoid':
        # torch.sigmoid replaces the deprecated F.sigmoid alias.
        return torch.sigmoid(score)
    if form == 'identity':
        return score
    if form == 'softplus':
        return F.softplus(score)
    raise ValueError('Unsupported energy_form: %s' % form)
# Run inference over every dataset: save the mean saliency prediction and a
# per-pixel entropy ("uncertainty") map for each test image.
for dataset in test_datasets:
    save_path_mean = './results_mean/' + dataset + '/'
    save_path_var = './results_var/' + dataset + '/'
    if not os.path.exists(save_path_mean):
        os.makedirs(save_path_mean)
    if not os.path.exists(save_path_var):
        os.makedirs(save_path_var)
    image_root = dataset_path + dataset + '/'
    test_loader = test_dataset(image_root, opt.testsize)
    for img_idx in range(test_loader.size):  # renamed: inner loop previously shadowed `i`
        print(img_idx)
        image, HH, WW, name = test_loader.load_data()
        image = image.cuda()
        # Draw opt.z_sample_iterations samples (default 10; previously a
        # hard-coded 10 that ignored the CLI flag).
        # NOTE(review): z_noise is all zeros, so every sample is identical and
        # the mean/entropy over samples is degenerate — confirm whether
        # torch.randn was intended here.
        sal_pred = []
        for _ in range(opt.z_sample_iterations):
            z_noise = torch.zeros(image.shape[0], opt.latent_dim).cuda()
            _, generator_pred = generator.forward(image, z_noise)
            sal_pred.append(torch.sigmoid(generator_pred))
        # Stack samples along the channel dim (replaces clone + cat loop that
        # shadowed the builtin `iter`) and reduce to the mean prediction.
        sal_preds = torch.cat(sal_pred, dim=1)
        mean_pred = torch.mean(sal_preds, dim=1, keepdim=True)
        # Per-pixel entropy term -p*log(p), used as the uncertainty map.
        var = -mean_pred * torch.log(mean_pred + 1e-8)
        for pred, save_path in ((mean_pred, save_path_mean), (var, save_path_var)):
            # F.interpolate replaces the deprecated F.upsample alias.
            res = F.interpolate(pred, size=[WW, HH], mode='bilinear', align_corners=False)
            res = res.data.cpu().numpy().squeeze()
            # Min-max normalize to [0, 255] before writing as an image.
            res = 255 * (res - res.min()) / (res.max() - res.min() + 1e-8)
            cv2.imwrite(save_path + name, res)