Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add author info replication script in gatys_ecker_betghe_2016 #296

Merged
merged 3 commits into from
Dec 13, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions replication/gatys_ecker_bethge_2016/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,67 @@ def figure_3(args):
)


def figure_2_author_information(args):
images = paper.images()
images.download(args.image_source_dir)

# Gatys provided this information
# compute layer_weights and use relu output -> impl_params=True
hyper_parameters = paper.hyper_parameters(impl_params=True)

hyper_parameters.nst.num_steps = 2000
hyper_parameters.nst.starting_point = "random"

content_image = images["neckarfront"].read(
size=hyper_parameters.nst.image_size, device=args.device
)

class StyleImage:
def __init__(self, label, image, weight_ratio):
self.label = label
self.image = image.read(
size=hyper_parameters.nst.image_size, device=args.device
)
self.weight_ratio = weight_ratio

style_images = ( # weight ratio from paper
StyleImage("B", images["shipwreck"], 1e-3),
StyleImage("C", images["starry_night"], 8e-4),
StyleImage("D", images["the_scream"], 5e-3),
StyleImage("E", images["femme_nue_assise"], 5e-4),
StyleImage("F", images["composition_vii"], 5e-4),
)

for style_image in style_images:
print(f"Replicating Figure 2 {style_image.label} with author information")
# The weight ratio from the paper is between a larger and a smaller value,
# according to the author the larger value was one of the discrete values 1e6 or
# 1e9 (1e9 leads to a Runtime Error)
larger_value = 1e6
hyper_parameters.style_loss.score_weight = larger_value
hyper_parameters.content_loss.score_weight = (
style_image.weight_ratio * larger_value
)

output_image = paper.nst(
content_image,
style_image.image,
hyper_parameters=hyper_parameters,
)

filename = utils.make_output_filename(
"gatys_ecker_betghe_2016",
"fig_2",
style_image.label,
"author",
impl_params=args.impl_params,
)
save_result(
output_image,
path.join(args.image_results_dir, filename),
)


def parse_input():
# TODO: write CLI
image_source_dir = None
Expand Down Expand Up @@ -154,3 +215,4 @@ def process_dir(dir):

figure_2(args)
figure_3(args)
figure_2_author_information(args)