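"""Inference script for a latent-diffusion inpainting model (DDIM sampling).

Loads a model/data config via OmegaConf, restores a checkpoint, and for each
test batch encodes the masked input, samples with DDIM conditioned on a
reference image and pose hint, decodes the result, and saves image grids.
"""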
import os
import argparse

import torch
import torchvision
import pytorch_lightning
import numpy as np
from PIL import Image
from torch import autocast
from einops import rearrange
from omegaconf import OmegaConf
from pytorch_lightning import seed_everything

from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler

def un_norm(x):
    # Map a tensor from [-1, 1] back to [0, 1].
    return (x + 1.0) / 2.0


def un_norm_clip(x):
    # Undo CLIP image normalization per channel: x * std + mean, using CLIP's
    # constants (mean 0.481/0.458/0.408, std 0.269/0.261/0.276).
    x[0, :, :] = x[0, :, :] * 0.26862954 + 0.48145466
    x[1, :, :] = x[1, :, :] * 0.26130258 + 0.4578275
    x[2, :, :] = x[2, :, :] * 0.27577711 + 0.40821073
    return x
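
# Example (hypothetical tensor name): for a CLIP-normalized reference image
# `ref` of shape (3, H, W), `un_norm_clip(ref)` maps it back to roughly [0, 1]
# for visualization. Note that it modifies its argument in place.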

class DataModuleFromConfig(pytorch_lightning.LightningDataModule):
    def __init__(self,
                 batch_size,                  # e.g. 1
                 test=None,                   # instantiate_from_config-style dict
                 wrap=False,
                 shuffle=False,
                 shuffle_test_loader=False,
                 use_worker_init_fn=False):
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = batch_size * 2
        self.use_worker_init_fn = use_worker_init_fn
        self.wrap = wrap
        self.datasets = instantiate_from_config(test)
        self.dataloader = torch.utils.data.DataLoader(self.datasets,
                                                      batch_size=self.batch_size,
                                                      num_workers=self.num_workers,
                                                      shuffle=shuffle,
                                                      worker_init_fn=None)
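
# The `test` argument above follows the ldm `instantiate_from_config`
# convention; a sketch of the expected YAML (the actual dataset class depends
# on configs/test.yaml):
#   test:
#     target: <module.path.to.DatasetClass>
#     params: {...}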

if __name__ == "__main__":
    # =============================================================
    # Parse command-line options
    # =============================================================
    parser = argparse.ArgumentParser()
    parser.add_argument("-b", "--base", type=str, default="configs/test.yaml")
    parser.add_argument("-c", "--ckpt", type=str, default="./model.ckpt")
    parser.add_argument("-s", "--seed", type=int, default=42)
    parser.add_argument("-d", "--ddim", type=int, default=64)
    opt = parser.parse_args()

    # =============================================================
    # Set the random seed
    # =============================================================
    seed_everything(opt.seed)

    # =============================================================
    # Load the config
    # =============================================================
    config = OmegaConf.load(f"{opt.base}")

    # =============================================================
    # Build the dataloader
    # =============================================================
    data = instantiate_from_config(config.data)
    print(f"{data.__class__.__name__}, {len(data.dataloader)}")

    # =============================================================
    # Load the model
    # =============================================================
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = instantiate_from_config(config.model)
    model.load_state_dict(torch.load(opt.ckpt, map_location="cpu")["state_dict"], strict=False)
    model = model.to(device)
    model.eval()
    sampler = DDIMSampler(model)
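
    # DDIMSampler comes from the ldm codebase; with eta=0 in the sampling call
    # below, the DDIM trajectory is deterministic (opt.ddim steps).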

    # =============================================================
    # Set the precision scope
    # =============================================================
    precision_scope = autocast

    # =============================================================
    # Run the test loop
    # =============================================================
    os.makedirs("results/Unpaired_Direct", exist_ok=True)
    os.makedirs("results/Unpaired_Concatenation", exist_ok=True)
    with torch.no_grad():
        with precision_scope("cuda"):
            for i, batch in enumerate(data.dataloader):
                # Load the batch
                inpaint = batch["inpaint_image"].to(torch.float16).to(device)
                reference = batch["ref_imgs"].to(torch.float16).to(device)
                mask = batch["inpaint_mask"].to(torch.float16).to(device)
                hint = batch["hint"].to(torch.float16).to(device)
                truth = batch["GT"].to(torch.float16).to(device)

                # Preprocess: encode the masked image to latents and resize the
                # mask to the latent resolution
                encoder_posterior_inpaint = model.first_stage_model.encode(inpaint)
                z_inpaint = model.scale_factor * (encoder_posterior_inpaint.sample()).detach()
                mask_resize = torchvision.transforms.Resize([z_inpaint.shape[-2], z_inpaint.shape[-1]])(mask)
                test_model_kwargs = {}
                test_model_kwargs['inpaint_image'] = z_inpaint
                test_model_kwargs['inpaint_mask'] = mask_resize
                shape = (model.channels, model.image_size, model.image_size)
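
                # Note: `pose` and `test_model_kwargs` are extensions in this
                # repo's DDIMSampler.sample; the stock ldm sampler does not
                # accept these keyword arguments.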
                # Run DDIM sampling
                samples, _ = sampler.sample(S=opt.ddim,
                                            batch_size=1,
                                            shape=shape,
                                            pose=hint,
                                            conditioning=reference,
                                            verbose=False,
                                            eta=0,
                                            test_model_kwargs=test_model_kwargs)

                # Decode the 4 latent channels back to pixel space
                samples = 1. / model.scale_factor * samples
                x_samples = model.first_stage_model.decode(samples[:, :4, :, :])
                x_samples_ddim = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
                x_samples_ddim = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy()
                x_checked_image = x_samples_ddim
                x_checked_image_torch = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2)
                # Save the images
                all_img = []
                all_img_C = []
                # all_img.append(un_norm(truth[0]).cpu())
                # all_img.append(un_norm(inpaint[0]).cpu())
                # all_img.append(un_norm_clip(torchvision.transforms.Resize([512,512])(reference)[0].cpu()))
                mask = mask.cpu()
                truth = torch.clamp((truth + 1.0) / 2.0, min=0.0, max=1.0).cpu()
                # Composite: generated pixels where mask == 0, ground truth
                # where mask == 1
                x_checked_image_torch_C = x_checked_image_torch * (1 - mask) + truth * mask
                x_checked_image_torch = torch.nn.functional.interpolate(x_checked_image_torch.float(), size=[512, 384])
                x_checked_image_torch_C = torch.nn.functional.interpolate(x_checked_image_torch_C.float(), size=[512, 384])
                all_img.append(x_checked_image_torch[0])
                all_img_C.append(x_checked_image_torch_C[0])

                grid = torch.stack(all_img, 0)
                grid = torchvision.utils.make_grid(grid)
                grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
                img = Image.fromarray(grid.astype(np.uint8))
                img.save("results/Unpaired_Direct/" + str(i) + ".png")

                grid_C = torch.stack(all_img_C, 0)
                grid_C = torchvision.utils.make_grid(grid_C)
                grid_C = 255. * rearrange(grid_C, 'c h w -> h w c').cpu().numpy()
                img_C = Image.fromarray(grid_C.astype(np.uint8))
                img_C.save("results/Unpaired_Concatenation/" + str(i) + ".png")
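
# Example invocation (defaults shown; paths depend on your checkout):
#   python test.py --base configs/test.yaml --ckpt ./model.ckpt --seed 42 --ddim 64
# Results are written to results/Unpaired_Direct and
# results/Unpaired_Concatenation, one grid PNG per batch.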