-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathexample.py
158 lines (128 loc) · 6.01 KB
/
example.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
import sys
import torch
from networks.relighting import Relighting
from third_party.wrappers import ImageCropPoser, WrapperLightingEstimator
from PIL import Image
import numpy as np
import imageio
from tqdm import tqdm
import glob
from utils import render_tensor, render_half_sphere, paste_light_on_img_tensor
# Global PyTorch performance settings, applied once at import time.
# NOTE(review): "high" trades some float32 matmul precision for speed per the
# PyTorch docs -- confirm the accuracy trade-off is acceptable for this model.
torch.set_float32_matmul_precision("high")
# Ask cuDNN to benchmark convolution algorithms and reuse the fastest one;
# presumably beneficial here because the input sizes stay fixed -- TODO confirm.
torch.backends.cudnn.benchmark = True
def main():
    """Run the whole demo pipeline end to end.

    The stages are: build the models, warm them up, reconstruct the example
    image, render a relit orbit video of that image, and finally relight the
    example video frames under interpolated lightings.
    """
    # Stage 1: choose the compute device and construct the models.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    crop_poser, relight_net, light_estimator = initialize_models(device)
    lightings = np.load("examples/example_lightings.npy")

    # Stage 2: one throwaway pass through every model.
    warm_up_models(crop_poser, relight_net, light_estimator, device)

    # Stage 3: render the example image under its own camera and lighting.
    # The estimated camera is returned but not needed by the later stages.
    cropped_image, _cam, planes = reconstruct_image(
        crop_poser, relight_net, light_estimator, device
    )

    # Stage 4: relight the image under novel views and lightings.
    perform_relighting(cropped_image, planes, relight_net, device, lightings)

    # Stage 5: relight the example video under novel lightings.
    perform_video_relighting(relight_net, device, lightings)
def initialize_models(device):
    """Construct the three models used by the demo and load the relighting weights.

    Args:
        device: ``torch.device`` handed to each model constructor and used as
            the target of ``.to(...)``.

    Returns:
        A tuple ``(cropposer, relighting, dpr)`` holding the
        ``ImageCropPoser``, the ``Relighting`` network (with weights loaded
        from ``checkpoints/model.pth``) and the ``WrapperLightingEstimator``.
    """
    cropposer = ImageCropPoser(device).to(device)
    relighting = Relighting(device).to(device)
    # map_location remaps the stored tensors onto the selected device, so a
    # checkpoint saved from a GPU can still be loaded on a CPU-only machine.
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    state_dict = torch.load("checkpoints/model.pth", map_location=device)
    relighting.load_state_dict(state_dict)
    dpr = WrapperLightingEstimator(device).to(device)
    return cropposer, relighting, dpr
def warm_up_models(cropposer, relighting, dpr, device):
    """Push the example image through every model once and discard the results.

    Args:
        cropposer: Model exposing ``wild2all``.
        relighting: Model exposing ``image_forward``.
        dpr: Lighting-estimator wrapper; its inner ``dpr.extract_lighting`` is called.
        device: ``torch.device`` on which the dummy inputs are created.
    """
    with torch.no_grad():
        dummy_input = preprocess_image(Image.open("examples/example.png"), device)

        cropposer.wild2all(dummy_input)

        # Random stand-ins for the camera (25 values) and lighting (9 values)
        # arguments; the outputs are thrown away.
        random_cam = torch.rand(1, 25, device=device)
        random_sh = torch.rand(1, 9, device=device)
        relighting.image_forward(dummy_input, random_cam, random_sh)

        dpr.dpr.extract_lighting(dummy_input)

    print("Models loaded")
def preprocess_image(image, device):
    """Convert an image into a batched, channel-first float tensor.

    Args:
        image: A PIL image, or an array-like in height x width x channel
            layout. Objects exposing a ``convert`` method (PIL images) are
            first converted to RGB so that RGBA, grayscale and palette images
            all yield three channels; arrays are used as given.
        device: ``torch.device`` the resulting tensor is moved to.

    Returns:
        A float tensor of shape ``(1, C, H, W)`` computed as
        ``value / 255 * 2 - 1`` (so 0 maps to -1 and 255 maps to 1).
    """
    # Duck-typed so plain ndarrays keep working exactly as before.
    convert = getattr(image, "convert", None)
    if callable(convert):
        image = convert("RGB")
    image = torch.from_numpy(np.array(image)).permute(2, 0, 1).unsqueeze(0)
    image = image.float().to(device) / 255 * 2 - 1
    return image
def reconstruct_image(cropposer, relighting, dpr, device):
    """Render the example image under its own estimated camera and lighting.

    Writes ``examples/recon_image.png`` (the network output) and
    ``examples/cropped_image.png`` (the cropped input).

    Args:
        cropposer: Model whose ``wild2all`` returns a dict with
            ``"img_cropped"`` and ``"cam"`` entries.
        relighting: Model exposing ``image_forward``.
        dpr: Lighting-estimator wrapper; its inner ``dpr.extract_lighting`` is called.
        device: ``torch.device`` the intermediate tensors are moved to.

    Returns:
        A tuple ``(cropped_image, cam, planes)`` where ``planes`` is the second
        value returned by ``relighting.image_forward``.
    """
    with torch.no_grad():
        source = preprocess_image(Image.open("examples/example.png"), device)

        crop_result = cropposer.wild2all(source)
        cropped_image = crop_result["img_cropped"].to(device)
        cam = crop_result["cam"].to(device)

        # Drop all singleton dimensions, then restore a leading batch dimension.
        estimated_light = dpr.dpr.extract_lighting(cropped_image)
        sh = estimated_light.squeeze().unsqueeze(0)

        outputs, planes, _ = relighting.image_forward(cropped_image, cam, sh)

        render_tensor(outputs["image"]).save("examples/recon_image.png")
        render_tensor(cropped_image).save("examples/cropped_image.png")

    return cropped_image, cam, planes
def perform_relighting(cropped_image, planes, relighting, device, example_lightings):
    """Render a video of the image under orbiting cameras and changing lightings.

    Writes ``examples/relighting.mp4`` at 24 fps. The video has 3 steps of 36
    frames each; within a step the camera pitch/yaw trace one full circle of
    radius 0.3, and each step uses the next entry of ``example_lightings``
    (wrapping around).

    Args:
        cropped_image: Image tensor passed to ``relighting.image_forward``.
        planes: Value forwarded as ``gt_planes`` on every call
            (presumably so the planes are reused instead of recomputed --
            TODO confirm against ``Relighting.image_forward``).
        relighting: Model exposing ``image_forward`` and
            ``encoders.eg3d.args2cam``.
        device: ``torch.device`` the per-frame tensors are moved to.
        example_lightings: Indexable NumPy array of lighting coefficient
            vectors; must be non-empty.
    """
    fps = 24
    frames_per_step = 36
    steps = 3
    total_frames = steps * frames_per_step
    # Mixed precision is used on CUDA only; with enabled=False the autocast
    # context is a no-op, so one code path serves both devices.
    use_amp = device.type == "cuda"

    with imageio.get_writer("examples/relighting.mp4", fps=fps) as video_writer:
        for step in tqdm(range(total_frames)):
            with torch.inference_mode():
                # The lighting changes once per step; the angle sweeps 0..2*pi within a step.
                idx = (step // frames_per_step) % len(example_lightings)
                angle = (step % frames_per_step) / frames_per_step * 2 * np.pi
                pitch = 0.3 * np.sin(angle)
                yaw = 0.3 * np.cos(angle)

                cam = relighting.encoders.eg3d.args2cam(pitch=pitch, yaw=yaw).to(device)
                sh = torch.from_numpy(example_lightings[idx]).unsqueeze(0).to(device)

                with torch.autocast(device_type="cuda", enabled=use_amp):
                    ret, _, _ = relighting.image_forward(
                        cropped_image, cam, sh, gt_planes=planes
                    )

                recon_image = ret["image"]
                frame = render_tensor(paste_light_on_img_tensor(64, sh, recon_image))
                video_writer.append_data(np.array(frame))
def perform_video_relighting(relighting, device, example_lightings):
    """Relight the example video frame by frame under interpolated lightings.

    Reads the frames from ``examples/video/cropped/*.jpg`` and the cameras from
    ``examples/video/camera/*.npy`` (both sorted by path and paired in order),
    and writes ``examples/video_relighting.mp4`` at 24 fps. If the two
    directories hold different numbers of files, only the common prefix is
    processed.

    For each frame the camera is the mean of the last (up to) 5 loaded cameras,
    and the lighting is a linear blend between consecutive entries of
    ``example_lightings`` that advances to the next pair every 20 frames.

    Args:
        relighting: Model exposing ``reset`` and ``video_forward``.
        device: ``torch.device`` the per-frame tensors are moved to.
        example_lightings: Indexable NumPy array of lighting coefficient
            vectors; must be non-empty when there are frames to process.
    """
    frames = sorted(glob.glob("examples/video/cropped/*.jpg"))
    cams = sorted(glob.glob("examples/video/camera/*.npy"))
    frames_per_lighting = 20
    smoothing_window = 5
    # Mixed precision is used on CUDA only; with enabled=False the autocast
    # context is a no-op, so one code path serves both devices.
    use_amp = device.type == "cuda"

    relighting.reset()
    prev_cam = []

    # Materialise the pairs so the progress bar total matches what is actually
    # processed, even when the two directories differ in length.
    pairs = list(zip(frames, cams))

    with imageio.get_writer("examples/video_relighting.mp4", fps=24) as video_writer:
        for idx, (frame_path, cam_path) in enumerate(tqdm(pairs)):
            with torch.inference_mode():
                frame = Image.open(frame_path)
                cam = np.load(cam_path)

                # Moving average over the most recent cameras.
                prev_cam.append(cam)
                if len(prev_cam) > smoothing_window:
                    prev_cam.pop(0)
                cam_avg = np.mean(prev_cam, axis=0)

                frame = preprocess_image(frame, device)
                cam_tensor = torch.from_numpy(cam_avg).to(device)

                # Blend linearly from the current lighting towards the next one.
                cur_idx = (idx // frames_per_lighting) % len(example_lightings)
                next_idx = (cur_idx + 1) % len(example_lightings)
                alpha = (idx % frames_per_lighting) / frames_per_lighting
                cur_sh = example_lightings[cur_idx]
                next_sh = example_lightings[next_idx]
                sh = (
                    torch.from_numpy((1 - alpha) * cur_sh + alpha * next_sh)
                    .unsqueeze(0)
                    .to(device)
                )

                with torch.autocast(device_type="cuda", enabled=use_amp):
                    ret = relighting.video_forward(frame, cam_tensor, sh)

                recon_image = ret["image"]
                frame_output = render_tensor(
                    paste_light_on_img_tensor(64, sh, recon_image)
                )
                video_writer.append_data(np.array(frame_output))
# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()