-
Notifications
You must be signed in to change notification settings - Fork 24
/
Copy pathgeneration_evaluation.py
54 lines (50 loc) · 2.35 KB
/
generation_evaluation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
'''
This code is used to evaluate the FID score of the generated images.
You should at least guarantee this code can run without any error on test set.
And whether this code can run is the most important factor for grading.
We provide the remaining code, you can't modify the remaining code, all you should do are:
1. Modify the sample function to get the generated images from the model and ensure the generated images are saved to the gen_data_dir(line 12-18)
2. Modify how you call your sample function(line 31)
'''
from pytorch_fid.fid_score import calculate_fid_given_paths
from utils import *
from model import *
from dataset import *
import os
import torch
# You should modify this sample function to get the generated images from your model
# This function should save the generated images to the gen_data_dir, which is fixed as 'samples'
# Begin of your code
sample_op = lambda x : sample_from_discretized_mix_logistic(x, 5)
def my_sample(model, gen_data_dir, sample_batch_size = 25, obs = (3,32,32), sample_op = sample_op):
for label in my_bidict:
print(f"Label: {label}")
#generate images for each label, each label has 25 images
sample_t = sample(model, sample_batch_size, obs, sample_op)
sample_t = rescaling_inv(sample_t)
save_images(sample_t, os.path.join(gen_data_dir), label=label)
pass
# End of your code
if __name__ == "__main__":
ref_data_dir = "data/test"
gen_data_dir = "samples"
BATCH_SIZE=128
device = "cuda" if torch.cuda.is_available() else "cpu"
if not os.path.exists(gen_data_dir):
os.makedirs(gen_data_dir)
#Begin of your code
#Load your model and generate images in the gen_data_dir
model = PixelCNN(nr_resnet=1, nr_filters=40, input_channels=3, nr_logistic_mix=5)
model = model.to(device)
model = model.eval()
my_sample(model=model, gen_data_dir=gen_data_dir)
#End of your code
paths = [gen_data_dir, ref_data_dir]
print("#generated images: {:d}, #reference images: {:d}".format(
len(os.listdir(gen_data_dir)), len(os.listdir(ref_data_dir))))
try:
fid_score = calculate_fid_given_paths(paths, BATCH_SIZE, device, dims=192)
print("Dimension {:d} works! fid score: {}".format(192, fid_score, gen_data_dir))
except:
print("Dimension {:d} fails!".format(192))
print("Average fid score: {}".format(fid_score))