test_otb_defense.py
#!/usr/bin/python
# --------------------------------------------------------
# DaSiamRPN
# Licensed under The MIT License
# Written by Qiang Wang (wangqiang2015 at ia.ac.cn)
# --------------------------------------------------------
import argparse
import json

import cv2
import numpy as np
import torch
from os import makedirs
from os.path import realpath, dirname, join, isdir, exists

from net import SiamRPNotb
from run_defense import SiamRPN_init, SiamRPN_track
from utils import rect_2_cxy_wh, cxy_wh_2_rect

parser = argparse.ArgumentParser(description='PyTorch SiamRPN OTB Test')
parser.add_argument('--dataset', dest='dataset', default='OTB2015', help='dataset name')
parser.add_argument('-v', '--visualization', dest='visualization', action='store_true',
                    help='whether to visualize results')
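
# Example invocation (a sketch, assuming the OTB frames live under data/OTB2015 next to
# this script and SiamRPNOTB.model sits alongside it, as the loaders below expect):
#   python test_otb_defense.py --dataset OTB2015 -v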


def track_video(model, video):
    image_save = 0
    toc, regions = 0, []
    image_files, gt = video['image_files'], video['gt']
    for f, image_file in enumerate(image_files):
        im = cv2.imread(image_file)  # TODO: batch load
        tic = cv2.getTickCount()
        if f == 0:  # init
            target_pos, target_sz = rect_2_cxy_wh(gt[f])
            state = SiamRPN_init(im, target_pos, target_sz, model)  # init tracker
            location = cxy_wh_2_rect(state['target_pos'], state['target_sz'])
            regions.append(gt[f])
            att_per = 0  # adversarial perturbation in attack
            def_per = 0  # adversarial perturbation in defense
        elif f > 0:  # tracking
            if f % 30 == 1:  # clean the perturbation from the last frame
                att_per = 0
                def_per = 0
                state, att_per, def_per = SiamRPN_track(state, im, f, regions[f - 1], att_per, def_per,
                                                        image_save, iter=10)  # gt_track
                location = cxy_wh_2_rect(state['target_pos'] + 1, state['target_sz'])
                regions.append(location)
            else:
                state, att_per, def_per = SiamRPN_track(state, im, f, regions[f - 1], att_per, def_per,
                                                        image_save, iter=5)  # gt_track
                location = cxy_wh_2_rect(state['target_pos'] + 1, state['target_sz'])
                regions.append(location)
        toc += cv2.getTickCount() - tic

        if args.visualization and f >= 0:  # visualization
            if f == 0:
                cv2.destroyAllWindows()
            if len(gt[f]) == 8:
                cv2.polylines(im, [np.array(gt[f], int).reshape((-1, 1, 2))], True, (0, 255, 0), 2)
            else:
                cv2.rectangle(im, (gt[f, 0], gt[f, 1]), (gt[f, 0] + gt[f, 2], gt[f, 1] + gt[f, 3]), (0, 255, 0), 2)
            if len(location) == 8:
                cv2.polylines(im, [location.reshape((-1, 1, 2))], True, (0, 255, 255), 2)
            else:
                location = [int(l) for l in location]
                cv2.rectangle(im, (location[0], location[1]),
                              (location[0] + location[2], location[1] + location[3]), (0, 255, 255), 2)
            cv2.putText(im, str(f), (40, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 255), 2)
            cv2.imshow(video['name'], im)
            cv2.waitKey(1)
    toc /= cv2.getTickFrequency()

    # save result
    video_path = join('test', args.dataset, 'DaSiamRPN_defense')
    if not isdir(video_path):
        makedirs(video_path)
    result_path = join(video_path, '{:s}.txt'.format(video['name']))
    with open(result_path, 'w') as fout:
        for x in regions:
            fout.write(','.join([str(i) for i in x]) + '\n')

    print('({:d}) Video: {:12s} Time: {:02.1f}s Speed: {:3.1f}fps'.format(
        v_id, video['name'], toc, f / toc))
    return f / toc
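
# Result-file format, as written above: one comma-separated [x, y, w, h] line per frame.
# The first line is the (0-indexed) ground-truth box used for initialization; later lines
# are tracker outputs with target_pos shifted by +1, presumably to restore the 1-indexed
# OTB convention.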


def load_dataset(dataset):
    base_path = join(realpath(dirname(__file__)), 'data', dataset)
    if not exists(base_path):
        print("Please download OTB dataset into `data` folder!")
        exit()
    json_path = join(realpath(dirname(__file__)), 'data', dataset + '.json')
    with open(json_path, 'r') as f:
        info = json.load(f)
    for v in info.keys():
        path_name = info[v]['name']
        info[v]['image_files'] = [join(base_path, path_name, 'img', im_f) for im_f in info[v]['image_files']]
        info[v]['gt'] = np.array(info[v]['gt_rect']) - [1, 1, 0, 0]  # our tracker is 0-index
        info[v]['name'] = v
    return info
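
# load_dataset() expects data/<dataset>.json to map each video key to a dict carrying at
# least 'name' (the sequence folder under data/<dataset>), 'image_files' (frame file names
# under <name>/img/) and 'gt_rect' (per-frame 1-indexed [x, y, w, h] boxes), which appears
# to match the OTB json layout used by DaSiamRPN.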


def main():
    global args, v_id
    args = parser.parse_args()

    net = SiamRPNotb()
    net.load_state_dict(torch.load(join(realpath(dirname(__file__)), 'SiamRPNOTB.model')))
    # net.load_state_dict(torch.load(join(realpath(dirname(__file__)), 'SiamRPNVOT.model')))
    net.eval().cuda()

    dataset = load_dataset(args.dataset)
    fps_list = []
    for v_id, video in enumerate(dataset.keys()):
        fps_list.append(track_video(net, dataset[video]))
    print('Mean Running Speed {:.1f}fps'.format(np.mean(np.array(fps_list))))


if __name__ == '__main__':
    main()