-
Notifications
You must be signed in to change notification settings - Fork 37
/
Copy pathtest_real_data.py
92 lines (76 loc) · 3.38 KB
/
test_real_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
from __future__ import print_function, division
import argparse
import logging
import numpy as np
import cv2
import os
from pathlib import Path
from tqdm import tqdm
from lib.human_loader import StereoHumanDataset
from lib.network import RtStereoHumanModel
from config.stereo_human_config import ConfigStereoHuman as config
from lib.utils import get_novel_calib
from lib.GaussianRender import pts2render
import torch
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
class StereoHumanRender:
    """Run an ``RtStereoHumanModel`` over a test sequence and save novel-view renders.

    The constructor builds the model and the dataset from the given config,
    moves the model to CUDA, restores a checkpoint when one is configured,
    and switches the model to eval mode.
    """

    def __init__(self, cfg_file, phase):
        """Build the model and dataset described by ``cfg_file``.

        Args:
            cfg_file: Config object. This class reads ``batch_size``,
                ``dataset``, ``restore_ckpt`` and ``test_out_path`` from it.
            phase: Forwarded to ``StereoHumanDataset`` as ``phase``.
        """
        self.cfg = cfg_file
        self.bs = self.cfg.batch_size

        self.model = RtStereoHumanModel(self.cfg, with_gs_render=True)
        self.dataset = StereoHumanDataset(self.cfg.dataset, phase=phase)

        self.model.cuda()
        if self.cfg.restore_ckpt:
            self.load_ckpt(self.cfg.restore_ckpt)
        self.model.eval()

    def infer_seqence(self, view_select, ratio=0.5):
        """Render one novel view per test frame and write each as a JPEG.

        The number of frames is the number of entries found in
        ``<cfg.dataset.test_data_root>/img``. Each result is written to
        ``<cfg.test_out_path>/<name>_novel.jpg``. The method name keeps its
        original spelling so existing callers continue to work.

        Args:
            view_select: Forwarded to ``dataset.get_test_item`` as ``source_id``.
            ratio: Forwarded to ``get_novel_calib`` as ``ratio``.
        """
        img_dir = os.path.join(self.cfg.dataset.test_data_root, 'img')
        total_frames = len(os.listdir(img_dir))

        for idx in tqdm(range(total_frames)):
            item = self.dataset.get_test_item(idx, source_id=view_select)
            data = self.fetch_data(item)
            data = get_novel_calib(data, self.cfg.dataset, ratio=ratio, intr_key='intr_ori', extr_key='extr_ori')
            with torch.no_grad():
                data, _, _ = self.model(data, is_train=False)
                data = pts2render(data, bg_color=self.cfg.dataset.bg_color)

            render_novel = self.tensor2np(data['novel_view']['img_pred'])
            out_path = self.cfg.test_out_path + '/%s_novel.jpg' % (data['name'])
            # cv2.imwrite reports failure through its return value rather than
            # by raising, so surface it instead of silently dropping the frame.
            if not cv2.imwrite(out_path, render_novel):
                logging.warning("Failed to write rendered image to %s", out_path)

    def tensor2np(self, img_tensor):
        """Convert the first image of a batched tensor into a uint8 array.

        Element 0 of the batch is taken, axis 1 is moved to the last position,
        values are scaled by 255, clipped to ``[0, 255]`` and the last axis is
        reversed (presumably RGB -> BGR for OpenCV; confirm against the
        renderer's channel order).

        Args:
            img_tensor: 4-D tensor indexed as (batch, channel, row, column).

        Returns:
            ``numpy.ndarray`` of dtype ``uint8`` with the channel axis last.
        """
        img_np = img_tensor.permute(0, 2, 3, 1)[0].detach().cpu().numpy()
        # Clip before the cast: out-of-range floats converted to uint8 wrap
        # around (or are platform-dependent) instead of saturating.
        img_np = np.clip(img_np * 255, 0, 255)
        return img_np[:, :, ::-1].astype(np.uint8)

    def fetch_data(self, data):
        """Move the ``'lmain'`` and ``'rmain'`` entries to CUDA and add a batch axis.

        Every value under those two keys is replaced in place by
        ``value.cuda().unsqueeze(0)``.

        Args:
            data: Mapping with ``'lmain'`` and ``'rmain'`` sub-mappings of tensors.

        Returns:
            The same ``data`` object, modified in place.
        """
        for view in ('lmain', 'rmain'):
            view_data = data[view]
            # Only existing keys are reassigned, so iterating items() is safe.
            for key, tensor in view_data.items():
                view_data[key] = tensor.cuda().unsqueeze(0)
        return data

    def load_ckpt(self, load_path):
        """Load the ``'network'`` state dict from a checkpoint into the model.

        Args:
            load_path: Path to the checkpoint file.

        Raises:
            FileNotFoundError: If ``load_path`` does not exist.
        """
        # Explicit raise instead of assert so the check survives ``python -O``.
        if not os.path.exists(load_path):
            raise FileNotFoundError(f"Checkpoint not found: {load_path}")
        logging.info(f"Loading checkpoint from {load_path} ...")
        ckpt = torch.load(load_path, map_location='cuda')
        self.model.load_state_dict(ckpt['network'], strict=True)
        logging.info("Parameter loading done")
if __name__ == '__main__':
    # Script entry point: parse CLI options, assemble the test-time config,
    # then render the whole sequence.
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s',
    )

    cli = argparse.ArgumentParser()
    cli.add_argument('--test_data_root', type=str, required=True)
    cli.add_argument('--ckpt_path', type=str, required=True)
    cli.add_argument('--src_view', type=int, nargs='+', required=True)
    cli.add_argument('--ratio', type=float, default=0.5)
    args = cli.parse_args()

    # Start from the stage-2 YAML and override the fields this script needs.
    cfg = config()
    cfg.load(os.path.join('./config', 'stage2.yaml'))
    cfg = cfg.get_cfg()

    cfg.defrost()
    cfg.batch_size = 1
    cfg.dataset.test_data_root = args.test_data_root
    cfg.dataset.use_processed_data = False
    cfg.restore_ckpt = args.ckpt_path
    cfg.test_out_path = './test_out'
    # Create the output directory before the config is frozen again.
    Path(cfg.test_out_path).mkdir(exist_ok=True, parents=True)
    cfg.freeze()

    renderer = StereoHumanRender(cfg, phase='test')
    renderer.infer_seqence(view_select=args.src_view, ratio=args.ratio)