Skip to content

Commit

Permalink
make it possible to merge inpainting model with non-inpainting one
Browse files Browse the repository at this point in the history
  • Loading branch information
AUTOMATIC1111 committed Dec 4, 2022
1 parent 8504db5 commit 44c46f0
Showing 1 changed file with 25 additions and 2 deletions.
27 changes: 25 additions & 2 deletions modules/extras.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,7 @@ def add_difference(theta0, theta1_2_diff, alpha):
primary_model_info = sd_models.checkpoints_list[primary_model_name]
secondary_model_info = sd_models.checkpoints_list[secondary_model_name]
teritary_model_info = sd_models.checkpoints_list.get(teritary_model_name, None)
result_is_inpainting_model = False

print(f"Loading {primary_model_info.filename}...")
theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu')
Expand Down Expand Up @@ -280,8 +281,22 @@ def add_difference(theta0, theta1_2_diff, alpha):

for key in tqdm.tqdm(theta_0.keys()):
if 'model' in key and key in theta_1:
a = theta_0[key]
b = theta_1[key]

theta_0[key] = theta_func2(theta_0[key], theta_1[key], multiplier)
# this enables merging an inpainting model (A) with another one (B);
# where normal model would have 4 channels, for latenst space, inpainting model would
# have another 4 channels for unmasked picture's latent space, plus one channel for mask, for a total of 9
if a.shape != b.shape and a.shape[0:1] + a.shape[2:] == b.shape[0:1] + b.shape[2:]:
if a.shape[1] == 4 and b.shape[1] == 9:
raise RuntimeError("When merging inpainting model with a normal one, A must be the inpainting model.")

assert a.shape[1] == 9 and b.shape[1] == 4, f"Bad dimensions for merged layer {key}: A={a.shape}, B={b.shape}"

theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)
result_is_inpainting_model = True
else:
theta_0[key] = theta_func2(a, b, multiplier)

if save_as_half:
theta_0[key] = theta_0[key].half()
Expand All @@ -295,8 +310,16 @@ def add_difference(theta0, theta1_2_diff, alpha):

ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path

filename = primary_model_info.model_name + '_' + str(round(1-multiplier, 2)) + '-' + secondary_model_info.model_name + '_' + str(round(multiplier, 2)) + '-' + interp_method.replace(" ", "_") + '-merged.' + checkpoint_format
filename = \
primary_model_info.model_name + '_' + str(round(1-multiplier, 2)) + '-' + \
secondary_model_info.model_name + '_' + str(round(multiplier, 2)) + '-' + \
interp_method.replace(" ", "_") + \
'-merged.' + \
("inpainting." if result_is_inpainting_model else "") + \
checkpoint_format

filename = filename if custom_name == '' else (custom_name + '.' + checkpoint_format)

output_modelname = os.path.join(ckpt_dir, filename)

print(f"Saving to {output_modelname}...")
Expand Down

1 comment on commit 44c46f0

@alfiedennen
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Im getting a size mismatch in pytorch when merging with 1.5 pruned EMAonly.

Please sign in to comment.