-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy patheval_gta.py
256 lines (196 loc) · 10.2 KB
/
eval_gta.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
import argparse
import os
import sys
from dataclasses import dataclass
from typing import Optional

import torch
from torch.utils.data import DataLoader

from game4loc.dataset.gta import GTADatasetEval, get_transforms
from game4loc.evaluate.gta import evaluate
from game4loc.models.model import DesModel
def parse_tuple(s):
    """Parse a comma-separated string of integers into a tuple.

    Used as an ``argparse`` ``type=`` callable, e.g. ``"0,1"`` -> ``(0, 1)``
    and ``"3"`` -> ``(3,)``.

    Args:
        s: Comma-separated integers such as ``"0"`` or ``"0,1"``.

    Returns:
        A tuple of ``int``, one entry per comma-separated item.

    Raises:
        argparse.ArgumentTypeError: If any item is not a valid integer
            (this includes empty items such as the one in ``"0,"``).
    """
    try:
        return tuple(int(item) for item in s.split(','))
    except ValueError as err:
        # Chain the ValueError so the offending token stays visible in
        # tracebacks instead of being silently replaced.
        raise argparse.ArgumentTypeError(
            "Tuple must be integers separated by commas"
        ) from err
@dataclass
class Configuration:
    """Settings for a GTA-UAV evaluation run.

    The defaults below apply when the class is instantiated directly; the
    command-line entry point of this script overwrites several of them
    with parsed arguments.
    """

    # ---- Model ----
    # Alternative backbone: 'convnext_base.fb_in22k_ft_in1k_384'
    model: str = 'vit_base_patch16_rope_reg1_gap_256.sbb_in1k'

    # Override model image size (used for both height and width).
    img_size: int = 384

    # ---- Evaluation ----
    batch_size: int = 128
    verbose: bool = True
    # One-element tuple: the trailing comma matters, ``(0)`` is just the
    # int 0 and breaks ``len(config.gpu_ids)``.
    gpu_ids: tuple = (0,)
    normalize_features: bool = True

    # With Fine Matching
    with_match: bool = False

    # Set num_workers to 0 if on Windows.
    num_workers: int = 0 if os.name == 'nt' else 4

    # Run on GPU if available.
    device: str = 'cuda' if torch.cuda.is_available() else 'cpu'

    # ---- Dataset ----
    # 'D2S' or 'S2D' (the two modes handled by the evaluation routine).
    query_mode: str = 'D2S'

    # Alternative root:
    # "/home/xmuairmud/data/GTA-UAV-data/GTA-UAV-Lidar/GTA-UAV-Lidar"
    data_root: str = "/home/xmuairmud/data/GTA-UAV-data/GTA-UAV-official/GTA-UAV-LR-hf"

    # Checkpoint to start from; ``None`` skips checkpoint loading.
    # Declared after the pre-existing fields so their positional order in
    # the generated ``__init__`` is unchanged.
    checkpoint_start: Optional[str] = 'pretrained/gta/cross_area/game4loc.pth'

    train_pairs_meta_file: str = 'cross-area-drone2sate-train.json'
    test_pairs_meta_file: str = 'cross-area-drone2sate-test.json'
    sate_img_dir: str = 'satellite'

    # ---- Options the evaluation routine / CLI read or set ----
    # Redirect printed output to ``log_path`` when True.
    log_to_file: bool = False
    log_path: Optional[str] = None
    # Passed to the model constructor as ``share_weights``.
    share_weights: bool = True
    test_mode: str = 'pos'
def _build_eval_datasets(config, val_transforms):
    """Create the query and gallery test datasets for ``config.query_mode``.

    Args:
        config: Evaluation configuration; ``data_root``,
            ``test_pairs_meta_file``, ``sate_img_dir`` and ``query_mode``
            are read here.
        val_transforms: Transform object handed to every ``GTADatasetEval``.

    Returns:
        ``(query_dataset, gallery_dataset, pairs_dict)``.

    Raises:
        ValueError: If ``config.query_mode`` is neither ``'D2S'`` nor
            ``'S2D'``.
    """
    # Keyword arguments shared by every dataset built below.
    # NOTE(review): ``mode`` is fixed to 'pos'; ``config.test_mode`` is not
    # consulted here -- confirm whether it should be.
    shared = dict(
        data_root=config.data_root,
        pairs_meta_file=config.test_pairs_meta_file,
        transforms=val_transforms,
        mode='pos',
        query_mode=config.query_mode,
    )

    if config.query_mode == 'D2S':
        query_dataset = GTADatasetEval(view="drone", **shared)
        gallery_dataset = GTADatasetEval(
            view="sate",
            sate_img_dir=config.sate_img_dir,
            **shared,
        )
        pairs_dict = query_dataset.pairs_drone2sate_dict
    elif config.query_mode == 'S2D':
        # The drone gallery is built first because its pair mapping is
        # handed to the satellite query dataset.
        gallery_dataset = GTADatasetEval(view="drone", **shared)
        pairs_dict = gallery_dataset.pairs_sate2drone_dict
        query_dataset = GTADatasetEval(
            view="sate",
            pairs_sate2drone_dict=pairs_dict,
            sate_img_dir=config.sate_img_dir,
            **shared,
        )
    else:
        raise ValueError(
            "Unsupported query_mode {!r}; expected 'D2S' or 'S2D'".format(
                config.query_mode
            )
        )

    return query_dataset, gallery_dataset, pairs_dict


def _distance_thresholds(pairs_meta_file):
    """Return the 50 distance thresholds used for the test log.

    Cross-area files (name contains ``'cross'``) use steps of 10 up to 500;
    everything else is treated as same-area with steps of 4 up to 200.
    """
    if 'cross' in pairs_meta_file:
        ####### Cross-area for total 500m/10m
        print("cross-area eval")
        step = 10
    else:
        ####### Same-area for total 200m/4m
        print("same-area eval")
        step = 4
    return [step * (i + 1) for i in range(50)]


def _run_evaluation(config):
    """Build model and data loaders, then run ``evaluate`` once.

    Returns whatever ``evaluate`` returns.
    """
    #-----------------------------------------------------------------------------#
    # Model                                                                       #
    #-----------------------------------------------------------------------------#
    print("\nModel: {}".format(config.model))

    model = DesModel(config.model,
                     pretrained=True,
                     img_size=config.img_size,
                     share_weights=config.share_weights)

    data_config = model.get_config()
    print(data_config)
    mean = data_config["mean"]
    std = data_config["std"]
    img_size = (config.img_size, config.img_size)

    # Load pretrained checkpoint.
    if config.checkpoint_start is not None:
        print("Start from:", config.checkpoint_start)
        # Map to CPU so a checkpoint saved from a GPU also loads on a
        # CPU-only host; the model is moved to config.device below.
        model_state_dict = torch.load(config.checkpoint_start,
                                      map_location='cpu')
        model.load_state_dict(model_state_dict, strict=True)

    # Data parallel
    print("GPUs available:", torch.cuda.device_count())
    gpu_ids = config.gpu_ids
    if isinstance(gpu_ids, int):
        # Tolerate a bare int (e.g. ``gpu_ids = (0)``), which has no len().
        gpu_ids = (gpu_ids,)
    if torch.cuda.device_count() > 1 and len(gpu_ids) > 1:
        model = torch.nn.DataParallel(model, device_ids=list(gpu_ids))

    # Model to device
    model = model.to(config.device)

    print("\nImage Size Query:", img_size)
    print("Image Size Ground:", img_size)
    print("Mean: {}".format(mean))
    print("Std:  {}\n".format(std))

    #-----------------------------------------------------------------------------#
    # DataLoader                                                                  #
    #-----------------------------------------------------------------------------#
    # Only the validation transform is needed for evaluation; the two
    # training transforms returned alongside it are discarded.
    val_transforms, _, _ = get_transforms(img_size, mean=mean, std=std)

    query_dataset_test, gallery_dataset_test, pairs_dict = _build_eval_datasets(
        config, val_transforms
    )

    def make_loader(dataset):
        # Fixed order (shuffle=False) for both query and gallery loaders.
        return DataLoader(dataset,
                          batch_size=config.batch_size,
                          num_workers=config.num_workers,
                          shuffle=False,
                          pin_memory=True)

    query_dataloader_test = make_loader(query_dataset_test)
    gallery_dataloader_test = make_loader(gallery_dataset_test)

    print("Query Images Test:", len(query_dataset_test))
    print("Gallery Images Test:", len(gallery_dataset_test))

    # For Test Log (distance threshold)
    dis_threshold_list = _distance_thresholds(config.test_pairs_meta_file)

    print("\n{}[{}]{}".format(30*"-", "Evaluating GTA-UAV", 30*"-"))

    return evaluate(
        config=config,
        model=model,
        query_loader=query_dataloader_test,
        gallery_loader=gallery_dataloader_test,
        query_list=query_dataset_test.images_name,
        gallery_list=gallery_dataset_test.images_name,
        pairs_dict=pairs_dict,
        ranks_list=[1, 5, 10],
        query_center_loc_xy_list=query_dataset_test.images_center_loc_xy,
        gallery_center_loc_xy_list=gallery_dataset_test.images_center_loc_xy,
        gallery_topleft_loc_xy_list=gallery_dataset_test.images_topleft_loc_xy,
        step_size=1000,
        dis_threshold_list=dis_threshold_list,
        cleanup=True,
        plot_acc_threshold=False,
        top10_log=False,
        with_match=config.with_match,
    )


def eval_script(config):
    """Evaluate a model on the GTA-UAV test split described by ``config``.

    When ``config.log_to_file`` is true, everything printed during the run
    is written to ``config.log_path`` instead of the current stdout.

    Args:
        config: Configuration object. Attributes read: ``log_to_file``,
            ``log_path``, ``model``, ``img_size``, ``share_weights``,
            ``checkpoint_start``, ``gpu_ids``, ``device``, ``query_mode``,
            ``data_root``, ``test_pairs_meta_file``, ``sate_img_dir``,
            ``batch_size``, ``num_workers`` and ``with_match``.

    Returns:
        The value returned by ``evaluate`` (presumably Recall@1, going by
        the name the original code bound it to -- confirm against
        ``game4loc.evaluate.gta``).

    Raises:
        ValueError: If ``config.query_mode`` is not ``'D2S'`` or ``'S2D'``.
    """
    if not config.log_to_file:
        return _run_evaluation(config)

    previous_stdout = sys.stdout
    with open(config.log_path, 'w') as log_file:
        sys.stdout = log_file
        try:
            return _run_evaluation(config)
        finally:
            # Restore stdout even if evaluation fails, so later output is
            # not written to a closed file.
            sys.stdout = previous_stdout
def parse_args(argv=None):
    """Parse command-line options for the GTA-UAV evaluation script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default)
            lets argparse read ``sys.argv[1:]``.

    Returns:
        The ``argparse.Namespace`` holding the parsed options.

    Raises:
        SystemExit: Raised by argparse on invalid arguments, including
            ``--log_to_file`` given without ``--log_path``.
    """
    parser = argparse.ArgumentParser(description="Evaluation script for gta.")

    parser.add_argument('--log_to_file', action='store_true',
                        help='Log saving to file')
    parser.add_argument('--log_path', type=str, default=None,
                        help='Log file path')
    parser.add_argument('--data_root', type=str, default='./data/GTA-UAV-data',
                        help='Data root')
    parser.add_argument('--test_pairs_meta_file', type=str,
                        default='cross-area-drone2sate-test.json',
                        help='Test metafile path')
    parser.add_argument('--model', type=str,
                        default='vit_base_patch16_rope_reg1_gap_256.sbb_in1k',
                        help='Model architecture')
    parser.add_argument('--no_share_weights', action='store_true',
                        help='Model not sharing weights')
    parser.add_argument('--with_match', action='store_true',
                        help='Test with post-process image matching (GIM, etc)')
    parser.add_argument('--gpu_ids', type=parse_tuple, default=(0, 1),
                        help='GPU ID')
    parser.add_argument('--batch_size', type=int, default=40,
                        help='Batch size')
    parser.add_argument('--checkpoint_start', type=str, default=None,
                        help='Checkpoint to load before evaluation')
    parser.add_argument('--test_mode', type=str, default='pos',
                        help='Test with positive pairs')
    parser.add_argument('--query_mode', type=str, default='D2S',
                        choices=('D2S', 'S2D'),
                        help='Retrieval direction: D2S (drone to satellite) '
                             'or S2D (satellite to drone)')

    args = parser.parse_args(argv)

    # Without a path the log redirection would try to open ``None``; report
    # it as a usage error instead.
    if args.log_to_file and args.log_path is None:
        parser.error("--log_to_file requires --log_path")

    return args
if __name__ == '__main__':
    # Script entry point: parse the CLI, overlay it on a default
    # Configuration, then run the evaluation.
    cli = parse_args()
    config = Configuration()

    # Options whose config attribute has the same name as the CLI option.
    for option in (
        'data_root',
        'test_pairs_meta_file',
        'log_to_file',
        'log_path',
        'batch_size',
        'gpu_ids',
        'checkpoint_start',
        'model',
        'test_mode',
        'query_mode',
        'with_match',
    ):
        setattr(config, option, getattr(cli, option))

    # The CLI exposes the negated flag, so invert it for the config.
    config.share_weights = not cli.no_share_weights

    eval_script(config)