Skip to content

Commit

Permalink
Merge pull request AUTOMATIC1111#1 from hnmr293/lowvram
Browse files Browse the repository at this point in the history
add lowvram option
  • Loading branch information
Mikubill authored Feb 13, 2023
2 parents 644e2a2 + abfd626 commit f305003
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 9 deletions.
28 changes: 23 additions & 5 deletions scripts/cldm.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,15 @@ def get_state_dict(d):


class PlugableControlModel(nn.Module):
def __init__(self, model_path, config_path, weight=1.0) -> None:
def __init__(self, model_path, config_path, weight=1.0, lowvram=False) -> None:
super().__init__()
config = OmegaConf.load(config_path)
self.control_model = ControlNet(**config.model.params.control_stage_config.params).cuda()
state_dict = load_state_dict(model_path, location='cuda')
if lowvram:
self.control_model = ControlNet(**config.model.params.control_stage_config.params).cpu()
state_dict = load_state_dict(model_path, location='cpu')
else:
self.control_model = ControlNet(**config.model.params.control_stage_config.params).cuda()
state_dict = load_state_dict(model_path, location='cuda')
if any([k.startswith("control_model.") for k, v in state_dict.items()]):
state_dict = {k.replace("control_model.", ""): v for k, v in state_dict.items() if k.startswith("control_model.")}

Expand All @@ -48,8 +52,9 @@ def __init__(self, model_path, config_path, weight=1.0) -> None:
self.only_mid_control = False
self.control = None
self.hint_cond = None
self.lowvram = lowvram

def hook(self, model):
def hook(self, model, parent_model):
outer = self

def forward(self, x, timesteps=None, context=None, **kwargs):
Expand Down Expand Up @@ -81,8 +86,21 @@ def forward(self, x, timesteps=None, context=None, **kwargs):
h = h.type(x.dtype)
return self.out(h)

def forward2(*args, **kwargs):
try:
if self.lowvram:
parent_model.first_stage_model.cpu()
parent_model.cond_stage_model.cpu()
self.control_model.cuda()
return forward(*args, **kwargs)
finally:
if self.lowvram:
self.control_model.cpu()
parent_model.first_stage_model.cuda()
parent_model.cond_stage_model.cuda()

model._original_forward = model.forward
model.forward = forward.__get__(model, UNetModel)
model.forward = forward2.__get__(model, UNetModel)

def notify(self, cond_like):
self.hint_cond = cond_like
Expand Down
11 changes: 7 additions & 4 deletions scripts/controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ def ui(self, is_img2img):
with gr.Row():
enabled = gr.Checkbox(label='Enable', value=False)
scribble_mode = gr.Checkbox(label='Scibble Mode (Reverse color)', value=False)
lowvram = gr.Checkbox(label='Low VRAM (8GB or below)', value=False)

ctrls += (enabled,)
self.infotext_fields.append((enabled, "ControlNet Enabled"))
Expand Down Expand Up @@ -186,6 +187,7 @@ def create_canvas(h, w):

create_button.click(fn=create_canvas, inputs=[canvas_height, canvas_width], outputs=[input_image])
ctrls += (input_image, scribble_mode, resize_mode)
ctrls += (lowvram,)

return ctrls

Expand Down Expand Up @@ -215,14 +217,15 @@ def restore_networks():
self.latest_network.restore(unet)
self.latest_network = None

enabled, module, model, weight,image, scribble_mode, resize_mode = args
enabled, module, model, weight,image, scribble_mode, resize_mode, lowvram = args

if not enabled:
restore_networks()
return

models_changed = self.latest_params[0] != module or self.latest_params[1] != model \
or self.latest_model_hash != p.sd_model.sd_model_hash or self.latest_network == None
or self.latest_model_hash != p.sd_model.sd_model_hash or self.latest_network == None \
or (self.latest_network is not None and self.latest_network.lowvram != lowvram)

if models_changed:
restore_networks()
Expand All @@ -241,9 +244,9 @@ def restore_networks():
raise ValueError(f"file not found: {model_path}")

print(f"loading preprocessor: {module}, model: {model}")
network = PlugableControlModel(model_path, os.path.join(cn_models_dir, "cldm_v15.yaml"), weight)
network = PlugableControlModel(model_path, os.path.join(cn_models_dir, "cldm_v15.yaml"), weight, lowvram=lowvram)
network.to(p.sd_model.device, dtype=p.sd_model.dtype)
network.hook(unet)
network.hook(unet, p.sd_model)

print(f"ControlNet model {model} loaded.")
self.latest_network = network
Expand Down

0 comments on commit f305003

Please sign in to comment.