Skip to content

Commit

Permalink
Add author info replication script in gatys_ecker_betghe_2016 (#296)
Browse files Browse the repository at this point in the history
* Add author info replication script

* lint

* Update comment and remove impl_param in nst() function
  • Loading branch information
jbueltemeier authored Dec 13, 2022
1 parent 568c2fa commit a8a87f4
Showing 1 changed file with 62 additions and 0 deletions.
62 changes: 62 additions & 0 deletions replication/gatys_ecker_bethge_2016/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,67 @@ def figure_3(args):
)


def figure_2_author_information(args):
images = paper.images()
images.download(args.image_source_dir)

# Gatys provided this information
# compute layer_weights and use relu output -> impl_params=True
hyper_parameters = paper.hyper_parameters(impl_params=True)

hyper_parameters.nst.num_steps = 2000
hyper_parameters.nst.starting_point = "random"

content_image = images["neckarfront"].read(
size=hyper_parameters.nst.image_size, device=args.device
)

class StyleImage:
def __init__(self, label, image, weight_ratio):
self.label = label
self.image = image.read(
size=hyper_parameters.nst.image_size, device=args.device
)
self.weight_ratio = weight_ratio

style_images = ( # weight ratio from paper
StyleImage("B", images["shipwreck"], 1e-3),
StyleImage("C", images["starry_night"], 8e-4),
StyleImage("D", images["the_scream"], 5e-3),
StyleImage("E", images["femme_nue_assise"], 5e-4),
StyleImage("F", images["composition_vii"], 5e-4),
)

for style_image in style_images:
print(f"Replicating Figure 2 {style_image.label} with author information")
# The weight ratio from the paper is between a larger and a smaller value,
# according to the author the larger value was one of the discrete values 1e6 or
# 1e9 (1e9 leads to a Runtime Error)
larger_value = 1e6
hyper_parameters.style_loss.score_weight = larger_value
hyper_parameters.content_loss.score_weight = (
style_image.weight_ratio * larger_value
)

output_image = paper.nst(
content_image,
style_image.image,
hyper_parameters=hyper_parameters,
)

filename = utils.make_output_filename(
"gatys_ecker_betghe_2016",
"fig_2",
style_image.label,
"author",
impl_params=args.impl_params,
)
save_result(
output_image,
path.join(args.image_results_dir, filename),
)


def parse_input():
# TODO: write CLI
image_source_dir = None
Expand Down Expand Up @@ -154,3 +215,4 @@ def process_dir(dir):

figure_2(args)
figure_3(args)
figure_2_author_information(args)

0 comments on commit a8a87f4

Please sign in to comment.