-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
187 lines (153 loc) · 8.4 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
import os
import json
import random
from datetime import datetime
from time import time
import clip
import pydiffvg
from adam import main_adam
from cmaes import main_cma_es
from utils import create_save_folder, get_active_models_from_arg, open_class_mapping, \
get_class_index_list
# Set before numpy/torch are imported below. Presumably this works around the
# "OMP: Error #15" abort raised when more than one OpenMP runtime gets loaded
# into the process — TODO confirm which dependency needed it.
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
import numpy as np
import torch
import argparse
from config import *
from render import *
from fitnesses import *
# Maps each --renderer-type value accepted on the command line to the renderer
# class that setup_args instantiates with the parsed arguments.
render_table = dict(
    chars=CharsRenderer,
    pylinhas=PylinhasRenderer,
    organic=OrganicRenderer,
    thinorg=ThinOrganicRenderer,
    pixeldraw=PixelRenderer,
    fastpixel=FastPixelRenderer,
    vqgan=VQGANRenderer,
    clipdraw=ClipDrawRenderer,
    vdiff=VDiffRenderer,
    biggan=BigGANRenderer,
    linedraw=LineDrawRenderer,
    fftdraw=FFTRenderer,
)
def _build_parser():
    """Create the command-line parser; every default comes from the config module."""
    parser = argparse.ArgumentParser(description="Evolve to objective")
    # Only the strategies dispatched in __main__ are accepted, so a typo is
    # rejected before any experiment folder is created.
    parser.add_argument('--evolution-type', default=EVOLUTION_TYPE, choices=("cmaes", "adam"),
                        help='Specify the type of evolution. (cmaes or adam). Default is {}.'.format(EVOLUTION_TYPE))
    parser.add_argument('--random-seed', default=RANDOM_SEED, type=int,
                        help='Use a specific random seed (for repeatability). Default is {}.'.format(RANDOM_SEED))
    parser.add_argument('--save-folder', default=SAVE_FOLDER,
                        help="Directory to experiment outputs. Default is {}.".format(SAVE_FOLDER))
    parser.add_argument('--n-gens', default=N_GENS, type=int,
                        help='Maximum generations. Default is {}.'.format(N_GENS))
    parser.add_argument('--pop-size', default=POP_SIZE, type=int,
                        help='Population size. Default is {}.'.format(POP_SIZE))
    parser.add_argument('--save-all', default=SAVE_ALL, action='store_true',
                        help='Save all Individual images. Default is {}.'.format(SAVE_ALL))
    parser.add_argument('--checkpoint-freq', default=CHECKPOINT_FREQ, type=int,
                        help='Checkpoint save frequency. Default is {}.'.format(CHECKPOINT_FREQ))
    parser.add_argument('--verbose', default=VERBOSE, action='store_true',
                        help='Verbose. Default is {}.'.format(VERBOSE))
    parser.add_argument('--num-lines', default=NUM_LINES, type=int,
                        help="Number of lines. Default is {}".format(NUM_LINES))
    # Restrict to the keys of render_table; previously an unknown name only
    # failed with a KeyError after the output folder had been written.
    parser.add_argument('--renderer-type', default=RENDERER, choices=sorted(render_table),
                        help="Choose the renderer. Default is {}".format(RENDERER))
    parser.add_argument('--img-size', default=IMG_SIZE, type=int,
                        help='Image dimensions during testing. Default is {}.'.format(IMG_SIZE))
    parser.add_argument('--target-class', default=TARGET_CLASS,
                        help='Which target classes to optimize. Default is {}.'.format(TARGET_CLASS))
    parser.add_argument("--networks", default=NETWORKS,
                        help="comma separated list of networks (no spaces). Default is {}.".format(NETWORKS))
    parser.add_argument('--target-fit', default=TARGET_FITNESS, type=float,
                        help='target fitness stopping criteria. Default is {}.'.format(TARGET_FITNESS))
    parser.add_argument('--from-checkpoint', default=FROM_CHECKPOINT,
                        help='Checkpoint file from which you want to continue evolving. Default is {}.'.format(FROM_CHECKPOINT))
    parser.add_argument('--sigma', default=SIGMA, type=float,
                        help='The initial standard deviation of the distribution. Default is {}.'.format(SIGMA))
    parser.add_argument('--clip-model', default=CLIP_MODEL,
                        help='Name of the CLIP model to use. Default is {}. Availables: {}'.format(CLIP_MODEL, clip.available_models()))
    parser.add_argument('--clip-prompts', default=None,
                        help='CLIP prompts to use for the generation. Default is the target class')
    parser.add_argument('--input-image', default=None, help='Image to use as input.')
    parser.add_argument('--adam-steps', default=ADAM_STEPS, type=int,
                        help='Number of steps from Adam. Default is {}.'.format(ADAM_STEPS))
    parser.add_argument('--lr', default=LR, type=float,
                        help='Learning rate for the Adam optimizer. Default is {}.'.format(LR))
    parser.add_argument('--lamarck', default=LAMARCK, action='store_true',
                        help='Lamarck. Default is {}.'.format(LAMARCK))
    return parser


def _experiment_prompt(args):
    """Return the label embedded in the experiment folder name.

    The CLIP prompt is used if given, else the input image, else the target class.
    """
    if args.clip_prompts:
        # "/" would otherwise split the folder name into nested directories.
        return args.clip_prompts.replace(" ", "_").replace("/", "_")
    if args.input_image:
        # Keep only the file name; directory components would create nested folders.
        return os.path.basename(args.input_image)
    return args.target_class


def _prepare_output_folder(args):
    """Set experiment_name, sub_folder and checkpoint on ``args`` and create the folder."""
    if args.from_checkpoint:
        args.experiment_name = args.from_checkpoint.replace("_checkpoint.pkl", "")
        args.sub_folder = "from_checkpoint"
    else:
        # NOTE(review): a seed of 0 is falsy, so it is labelled with a timestamp
        # instead (and is not applied by _seed_everything) — confirm whether 0 is
        # meant as the "no seed" sentinel in config.RANDOM_SEED.
        run_tag = args.random_seed if args.random_seed else datetime.now().strftime('%Y-%m-%d_%H-%M')
        args.experiment_name = f"{args.renderer_type}_L{args.num_lines}_{_experiment_prompt(args)}_{run_tag}"
        args.sub_folder = f"{args.experiment_name}_{args.n_gens}_{args.pop_size}"
    save_folder, _ = create_save_folder(args.save_folder, args.sub_folder)
    # NOTE(review): when not resuming, from_checkpoint is falsy, so this path ends
    # in its str() (e.g. ".../None"). Kept unchanged because the consumers of
    # args.checkpoint are not visible from this file.
    args.checkpoint = "{}/{}".format(save_folder, args.from_checkpoint)


def _write_config(args):
    """Dump the parsed arguments to <save_folder>/<sub_folder>/config.json.

    Must run before non-JSON-serialisable objects (device, renderer, CLIP model)
    are attached to ``args``.
    """
    with open(f"{args.save_folder}/{args.sub_folder}/config.json", 'w') as f:
        json.dump(vars(args), f)


def _seed_everything(seed):
    """Seed Python's, NumPy's and torch's RNGs when a truthy seed is given."""
    if seed:
        print("Setting random seed: ", seed)
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)


def _load_clip(args):
    """Load args.clip_model onto args.device and store the model and preprocess on ``args``."""
    if args.clip_model not in clip.available_models():
        print(f"CLIP model {args.clip_model!r} is not available, falling back to ViT-B/32.")
        args.clip_model = "ViT-B/32"
    print(f"Loading CLIP model: {args.clip_model}")
    args.clip, args.preprocess = clip.load(args.clip_model, device=args.device)
    print("CLIP module loaded.")


def _build_fitnesses(args):
    """Return the list of fitness objectives selected by the CLI arguments."""
    fitnesses = []
    if args.clip_prompts:
        fitnesses.append(ClipPrompt(args.clip_prompts, model=args.clip, preprocess=args.preprocess))
    if args.input_image:
        fitnesses.append(InputImage(args.input_image, model=args.clip, preprocess=args.preprocess))
    # Optional extra objectives, currently disabled:
    # fitnesses.append(PaletteFitness(palette=[[0/255.0, 0/255.0, 0/255.0], [255/255.0, 241/255.0, 232/255.0]]))
    # fitnesses.append(AestheticFitness())
    # fitnesses.append(Aesthetic2Fitness())
    # fitnesses.append(GaussianFitness())
    # fitnesses.append(ResmemFitness())
    # fitnesses.append(SaturationFitness())
    # fitnesses.append(SmoothnessFitness())
    # fitnesses.append(SymmetryFitness())
    # fitnesses.append(StyleFitness(style_file="style.png"))
    # fitnesses.append(EdgeFitness())
    return fitnesses


def setup_args():
    """Parse the command line and build the full run configuration.

    Creates the experiment output folder and writes config.json. Seeds the RNGs
    when a seed is given, then loads the renderer and the CLIP model.

    Returns:
        The argparse Namespace, extended with experiment_name, sub_folder,
        checkpoint, imagenet_indexes, device, normalization, renderer, clip,
        preprocess and fitnesses.
    """
    args = _build_parser().parse_args()

    # Runs with a population of 1 or less use Adam. Switch before config.json is
    # written so the saved configuration records the optimizer that actually runs.
    if args.pop_size <= 1:
        print(f"Population size as {args.pop_size}, changing to Adam.")
        args.evolution_type = "adam"

    _prepare_output_folder(args)
    _write_config(args)

    class_mapping = open_class_mapping()
    if args.target_class is None or args.target_class == "none":
        args.imagenet_indexes = None
    else:
        args.imagenet_indexes = get_class_index_list(class_mapping, args.target_class)

    _seed_everything(args.random_seed)

    # Use GPU if available
    args.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("Using device:", args.device)
    pydiffvg.set_use_gpu(torch.cuda.is_available())
    pydiffvg.set_device(args.device)

    args.normalization = (args.renderer_type == "biggan")
    args.renderer = render_table[args.renderer_type](args)

    # Classifier-network loading is currently disabled:
    # args.active_models = get_active_models_from_arg(args.networks)
    # args.active_models_quantity = len(args.active_models.keys())
    # print("Loaded models:")
    # for key, value in args.active_models.items():
    #     print("- ", key)

    _load_clip(args)
    args.fitnesses = _build_fitnesses(args)
    return args
if __name__ == "__main__":
    # Wall-clock start of the whole program (argument parsing and model loading included).
    start_time_total = time()
    args = setup_args()

    # Wall-clock start of the optimisation itself.
    start_time_evo = time()
    evolution_runners = {"adam": main_adam, "cmaes": main_cma_es}
    run_evolution = evolution_runners.get(args.evolution_type)
    if run_evolution is None:
        # Nothing ran: exit with a non-zero status instead of reporting timings
        # for an evolution that never happened.
        raise SystemExit("The used evolution mode is not defined. Please choose one of the following (\"cmaes\", \"adam\")")
    run_evolution(args)

    end_time = time()
    evo_time = (end_time - start_time_evo)
    total_time = (end_time - start_time_total)
    print("-" * 20)
    print("Evolution elapsed time: {:.3f}".format(evo_time))
    print("Total elapsed time: {:.3f}".format(total_time))
    print("-" * 20)