-
Notifications
You must be signed in to change notification settings - Fork 611
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Skip two-view re-estimation by default
- Loading branch information
Showing
7 changed files
with
1,012,116 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
from dsfm.models import get_model | ||
|
||
from ..utils.base_model import BaseModel | ||
|
||
|
||
class InternalModel(BaseModel): | ||
default_conf = { | ||
'model_name': None, | ||
'filter': [], | ||
} | ||
required_inputs = ['image'] | ||
|
||
def _init(self, conf): | ||
assert conf['model_name'] is not None | ||
conf_ = {k: conf[k] for k in conf if k not in ['model_name', 'filter']} | ||
self.net = get_model(conf['model_name'])(conf_) | ||
|
||
def _forward(self, data): | ||
pred = self.net(data) | ||
if self.conf['filter']: | ||
pred = {k: pred[k] for k in pred if k in self.conf['filter']} | ||
return pred |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,155 @@ | ||
import argparse | ||
import torch | ||
from pathlib import Path | ||
import h5py | ||
import logging | ||
from tqdm import tqdm | ||
import pprint | ||
|
||
from . import matchers | ||
from .utils.base_model import dynamic_load | ||
from .utils.parsers import names_to_pair | ||
from .utils.tools import map_tensor | ||
|
||
|
||
''' | ||
A set of standard configurations that can be directly selected from the command | ||
line using their name. Each is a dictionary with the following entries: | ||
- output: the name of the match file that will be generated. | ||
- model: the model configuration, as passed to a feature matcher. | ||
''' | ||
confs = { | ||
'superglue': { | ||
'output': 'matches-superglue', | ||
'model': { | ||
'name': 'superglue', | ||
'weights': 'outdoor', | ||
'sinkhorn_iterations': 50, | ||
}, | ||
}, | ||
'NN': { | ||
'output': 'matches-NN-mutual-dist.7', | ||
'model': { | ||
'name': 'nearest_neighbor', | ||
'mutual_check': True, | ||
'distance_threshold': 0.7, | ||
}, | ||
} | ||
} | ||
|
||
|
||
class FeatureLoader(torch.utils.data.Dataset): | ||
def __init__(self, pairs, feature_path): | ||
self.pairs = pairs | ||
self.feature_path = feature_path | ||
|
||
def __getitem__(self, idx): | ||
data = {} | ||
name0, name1 = self.pairs[idx] | ||
data['pair'] = names_to_pair(name0, name1) | ||
|
||
with h5py.File(self.feature_path, 'r') as feature_file: | ||
for name, suffix in [(name0, '0'), (name1, '1')]: | ||
feats = feature_file[name] | ||
for k in feats.keys(): | ||
x = feats[k].__array__() | ||
data[k+suffix] = torch.from_numpy(x).float() | ||
# some matchers might expect an image but only use its size | ||
image_shape = (1,) + tuple(feats['image_size'])[::-1] | ||
data['image'+suffix] = torch.empty(image_shape) | ||
|
||
return data | ||
|
||
def __len__(self): | ||
return len(self.pairs) | ||
|
||
|
||
@torch.no_grad() | ||
def main(conf, pairs, features, export_dir, exhaustive=False): | ||
logging.info('Matching local features with configuration:' | ||
f'\n{pprint.pformat(conf)}') | ||
|
||
feature_path = Path(export_dir, features+'.h5') | ||
assert feature_path.exists(), feature_path | ||
pairs_name = pairs.stem | ||
match_name = f'{features}_{conf["output"]}_{pairs_name}' | ||
match_path = Path(export_dir, match_name+'.h5') | ||
|
||
if not exhaustive: | ||
assert pairs.exists(), pairs | ||
with open(pairs, 'r') as f: | ||
pair_list = f.read().rstrip('\n').split('\n') | ||
elif exhaustive: | ||
logging.info(f'Writing exhaustive match pairs to {pairs}.') | ||
assert not pairs.exists(), pairs | ||
|
||
# get the list of images from the feature file | ||
images = [] | ||
with h5py.File(str(feature_path), 'r') as feature_file: | ||
feature_file.visititems( | ||
lambda name, obj: images.append(obj.parent.name.strip('/')) | ||
if isinstance(obj, h5py.Dataset) else None) | ||
images = list(set(images)) | ||
|
||
pair_list = [' '.join((images[i], images[j])) | ||
for i in range(len(images)) for j in range(i)] | ||
with open(str(pairs), 'w') as f: | ||
f.write('\n'.join(pair_list)) | ||
|
||
device = 'cuda' if torch.cuda.is_available() else 'cpu' | ||
Model = dynamic_load(matchers, conf['model']['name']) | ||
model = Model(conf['model']).eval().to(device) | ||
|
||
# precompute the list of pairs missing in the file | ||
pair_list_unique = [] | ||
matched = set() | ||
for pair in pair_list: | ||
name0, name1 = pair.split(' ') | ||
pair = names_to_pair(name0, name1) | ||
# Avoid to recompute duplicates to save time | ||
if len({(name0, name1), (name1, name0)} & matched): | ||
continue | ||
if match_path.exists(): | ||
with h5py.File(str(match_path), 'r') as match_file: | ||
if pair in match_file: | ||
continue | ||
pair_list_unique.append((name0, name1)) | ||
matched |= {(name0, name1), (name1, name0)} | ||
|
||
loader = FeatureLoader(pair_list_unique, feature_path) | ||
loader = torch.utils.data.DataLoader( | ||
loader, num_workers=1, pin_memory=True) | ||
|
||
for data in tqdm(loader, smoothing=.1): | ||
data_ = map_tensor(data, lambda x: x.to(device)) | ||
pred = model(data_) | ||
pair = data['pair'][0] | ||
|
||
with h5py.File(str(match_path), 'a') as match_file: | ||
grp = match_file.create_group(pair) | ||
matches = pred['matches0'][0].cpu().short().numpy() | ||
grp.create_dataset('matches0', data=matches) | ||
|
||
if 'matching_scores0' in pred: | ||
scores = pred['matching_scores0'][0].cpu().half().numpy() | ||
grp.create_dataset('matching_scores0', data=scores) | ||
|
||
match_file.close() | ||
logging.info('Finished exporting matches.') | ||
|
||
return match_path | ||
|
||
|
||
if __name__ == '__main__': | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument('--export_dir', type=Path, required=True) | ||
parser.add_argument('--features', type=str, | ||
default='feats-superpoint-n4096-r1024') | ||
parser.add_argument('--pairs', type=Path, required=True) | ||
parser.add_argument('--conf', type=str, default='superglue', | ||
choices=list(confs.keys())) | ||
parser.add_argument('--exhaustive', action='store_true') | ||
args = parser.parse_args() | ||
main( | ||
confs[args.conf], args.pairs, args.features, args.export_dir, | ||
exhaustive=args.exhaustive) |
Oops, something went wrong.