Skip to content

Commit

Permalink
Merge branch 'AUTOMATIC1111:master' into master
Browse files Browse the repository at this point in the history
  • Loading branch information
InvincibleDude authored Jan 24, 2023
2 parents 3bc8ee9 + 602a186 commit 44c0e6b
Show file tree
Hide file tree
Showing 39 changed files with 1,198 additions and 675 deletions.
28 changes: 16 additions & 12 deletions extensions-builtin/Lora/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ class LoraUpDownModule:
def __init__(self):
self.up = None
self.down = None
self.alpha = None


def assign_lora_names_to_compvis_modules(sd_model):
Expand Down Expand Up @@ -92,6 +93,15 @@ def load_lora(name, filename):
keys_failed_to_match.append(key_diffusers)
continue

lora_module = lora.modules.get(key, None)
if lora_module is None:
lora_module = LoraUpDownModule()
lora.modules[key] = lora_module

if lora_key == "alpha":
lora_module.alpha = weight.item()
continue

if type(sd_module) == torch.nn.Linear:
module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
elif type(sd_module) == torch.nn.Conv2d:
Expand All @@ -104,17 +114,12 @@ def load_lora(name, filename):

module.to(device=devices.device, dtype=devices.dtype)

lora_module = lora.modules.get(key, None)
if lora_module is None:
lora_module = LoraUpDownModule()
lora.modules[key] = lora_module

if lora_key == "lora_up.weight":
lora_module.up = module
elif lora_key == "lora_down.weight":
lora_module.down = module
else:
assert False, f'Bad Lora layer name: {key_diffusers} - must end in lora_up.weight or lora_down.weight'
assert False, f'Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha'

if len(keys_failed_to_match) > 0:
print(f"Failed to match keys when loading Lora {filename}: {keys_failed_to_match}")
Expand Down Expand Up @@ -161,7 +166,7 @@ def lora_forward(module, input, res):
for lora in loaded_loras:
module = lora.modules.get(lora_layer_name, None)
if module is not None:
res = res + module.up(module.down(input)) * lora.multiplier
res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)

return res

Expand All @@ -177,12 +182,12 @@ def lora_Conv2d_forward(self, input):
def list_available_loras():
available_loras.clear()

os.makedirs(lora_dir, exist_ok=True)
os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)

candidates = \
glob.glob(os.path.join(lora_dir, '**/*.pt'), recursive=True) + \
glob.glob(os.path.join(lora_dir, '**/*.safetensors'), recursive=True) + \
glob.glob(os.path.join(lora_dir, '**/*.ckpt'), recursive=True)
glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.pt'), recursive=True) + \
glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.safetensors'), recursive=True) + \
glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.ckpt'), recursive=True)

for filename in sorted(candidates):
if os.path.isdir(filename):
Expand All @@ -193,7 +198,6 @@ def list_available_loras():
available_loras[name] = LoraOnDisk(name, filename)


lora_dir = os.path.join(shared.models_path, "Lora")
available_loras = {}
loaded_loras = []

Expand Down
6 changes: 6 additions & 0 deletions extensions-builtin/Lora/preload.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import os
from modules import paths


def preload(parser):
parser.add_argument("--lora-dir", type=str, help="Path to directory with Lora networks.", default=os.path.join(paths.models_path, 'Lora'))
5 changes: 3 additions & 2 deletions extensions-builtin/Lora/ui_extra_networks_lora.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import os
import lora

Expand Down Expand Up @@ -26,10 +27,10 @@ def list_items(self):
"name": name,
"filename": path,
"preview": preview,
"prompt": f"<lora:{name}:1.0>",
"prompt": json.dumps(f"<lora:{name}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">"),
"local_preview": path + ".png",
}

def allowed_directories_for_previews(self):
return [lora.lora_dir]
return [shared.cmd_opts.lora_dir]

8 changes: 7 additions & 1 deletion extensions-builtin/SwinIR/scripts/swinir_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from tqdm import tqdm

from modules import modelloader, devices, script_callbacks, shared
from modules.shared import cmd_opts, opts
from modules.shared import cmd_opts, opts, state
from swinir_model_arch import SwinIR as net
from swinir_model_arch_v2 import Swin2SR as net2
from modules.upscaler import Upscaler, UpscalerData
Expand Down Expand Up @@ -145,7 +145,13 @@ def inference(img, model, tile, tile_overlap, window_size, scale):

with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar:
for h_idx in h_idx_list:
if state.interrupted or state.skipped:
break

for w_idx in w_idx_list:
if state.interrupted or state.skipped:
break

in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
out_patch = model(in_patch)
out_patch_mask = torch.ones_like(out_patch)
Expand Down
4 changes: 2 additions & 2 deletions html/extra-networks-card.html
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
<div class='card' {preview_html} onclick='return cardClicked({tabname}, {prompt}, {allow_negative_prompt})'>
<div class='card' {preview_html} onclick={card_clicked}>
<div class='actions'>
<div class='additional'>
<ul>
<a href="#" title="replace preview image with currently selected in gallery" onclick='return saveCardPreview(event, {tabname}, {local_preview})'>replace preview</a>
<a href="#" title="replace preview image with currently selected in gallery" onclick={save_card_preview}>replace preview</a>
</ul>
</div>
<span class='name'>{name}</span>
Expand Down
7 changes: 7 additions & 0 deletions html/image-update.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
110 changes: 66 additions & 44 deletions javascript/edit-attention.js
Original file line number Diff line number Diff line change
@@ -1,74 +1,96 @@
addEventListener('keydown', (event) => {
function keyupEditAttention(event){
let target = event.originalTarget || event.composedPath()[0];
if (!target.matches("[id*='_toprow'] textarea.gr-text-input[placeholder]")) return;
if (! (event.metaKey || event.ctrlKey)) return;


let plus = "ArrowUp"
let minus = "ArrowDown"
if (event.key != plus && event.key != minus) return;
let isPlus = event.key == "ArrowUp"
let isMinus = event.key == "ArrowDown"
if (!isPlus && !isMinus) return;

let selectionStart = target.selectionStart;
let selectionEnd = target.selectionEnd;
// If the user hasn't selected anything, let's select their current parenthesis block
if (selectionStart === selectionEnd) {
let text = target.value;

function selectCurrentParenthesisBlock(OPEN, CLOSE){
if (selectionStart !== selectionEnd) return false;

// Find opening parenthesis around current cursor
const before = target.value.substring(0, selectionStart);
let beforeParen = before.lastIndexOf("(");
if (beforeParen == -1) return;
let beforeParenClose = before.lastIndexOf(")");
const before = text.substring(0, selectionStart);
let beforeParen = before.lastIndexOf(OPEN);
if (beforeParen == -1) return false;
let beforeParenClose = before.lastIndexOf(CLOSE);
while (beforeParenClose !== -1 && beforeParenClose > beforeParen) {
beforeParen = before.lastIndexOf("(", beforeParen - 1);
beforeParenClose = before.lastIndexOf(")", beforeParenClose - 1);
beforeParen = before.lastIndexOf(OPEN, beforeParen - 1);
beforeParenClose = before.lastIndexOf(CLOSE, beforeParenClose - 1);
}

// Find closing parenthesis around current cursor
const after = target.value.substring(selectionStart);
let afterParen = after.indexOf(")");
if (afterParen == -1) return;
let afterParenOpen = after.indexOf("(");
const after = text.substring(selectionStart);
let afterParen = after.indexOf(CLOSE);
if (afterParen == -1) return false;
let afterParenOpen = after.indexOf(OPEN);
while (afterParenOpen !== -1 && afterParen > afterParenOpen) {
afterParen = after.indexOf(")", afterParen + 1);
afterParenOpen = after.indexOf("(", afterParenOpen + 1);
afterParen = after.indexOf(CLOSE, afterParen + 1);
afterParenOpen = after.indexOf(OPEN, afterParenOpen + 1);
}
if (beforeParen === -1 || afterParen === -1) return;
if (beforeParen === -1 || afterParen === -1) return false;

// Set the selection to the text between the parenthesis
const parenContent = target.value.substring(beforeParen + 1, selectionStart + afterParen);
const parenContent = text.substring(beforeParen + 1, selectionStart + afterParen);
const lastColon = parenContent.lastIndexOf(":");
selectionStart = beforeParen + 1;
selectionEnd = selectionStart + lastColon;
target.setSelectionRange(selectionStart, selectionEnd);
}
return true;
}

// If the user hasn't selected anything, let's select their current parenthesis block
if(! selectCurrentParenthesisBlock('<', '>')){
selectCurrentParenthesisBlock('(', ')')
}

event.preventDefault();

if (selectionStart == 0 || target.value[selectionStart - 1] != "(") {
target.value = target.value.slice(0, selectionStart) +
"(" + target.value.slice(selectionStart, selectionEnd) + ":1.0)" +
target.value.slice(selectionEnd);
closeCharacter = ')'
delta = opts.keyedit_precision_attention

if (selectionStart > 0 && text[selectionStart - 1] == '<'){
closeCharacter = '>'
delta = opts.keyedit_precision_extra
} else if (selectionStart == 0 || text[selectionStart - 1] != "(") {

// do not include spaces at the end
while(selectionEnd > selectionStart && text[selectionEnd-1] == ' '){
selectionEnd -= 1;
}
if(selectionStart == selectionEnd){
return
}

target.focus();
target.selectionStart = selectionStart + 1;
target.selectionEnd = selectionEnd + 1;
text = text.slice(0, selectionStart) + "(" + text.slice(selectionStart, selectionEnd) + ":1.0)" + text.slice(selectionEnd);

} else {
end = target.value.slice(selectionEnd + 1).indexOf(")") + 1;
weight = parseFloat(target.value.slice(selectionEnd + 1, selectionEnd + 1 + end));
if (isNaN(weight)) return;
if (event.key == minus) weight -= 0.1;
if (event.key == plus) weight += 0.1;
selectionStart += 1;
selectionEnd += 1;
}

weight = parseFloat(weight.toPrecision(12));
end = text.slice(selectionEnd + 1).indexOf(closeCharacter) + 1;
weight = parseFloat(text.slice(selectionEnd + 1, selectionEnd + 1 + end));
if (isNaN(weight)) return;

target.value = target.value.slice(0, selectionEnd + 1) +
weight +
target.value.slice(selectionEnd + 1 + end - 1);
weight += isPlus ? delta : -delta;
weight = parseFloat(weight.toPrecision(12));
if(String(weight).length == 1) weight += ".0"

target.focus();
target.selectionStart = selectionStart;
target.selectionEnd = selectionEnd;
}
text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + 1 + end - 1);

target.focus();
target.value = text;
target.selectionStart = selectionStart;
target.selectionEnd = selectionEnd;

updateInput(target)
});
}

addEventListener('keydown', (event) => {
keyupEditAttention(event);
});
4 changes: 2 additions & 2 deletions javascript/extraNetworks.js
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@ function setupExtraNetworksForTab(tabname){
tabs.appendChild(close)

search.addEventListener("input", function(evt){
searchTerm = search.value
searchTerm = search.value.toLowerCase()

gradioApp().querySelectorAll('#'+tabname+'_extra_tabs div.card').forEach(function(elem){
text = elem.querySelector('.name').textContent
text = elem.querySelector('.name').textContent.toLowerCase()
elem.style.display = text.indexOf(searchTerm) == -1 ? "none" : ""
})
});
Expand Down
5 changes: 4 additions & 1 deletion javascript/hints.js
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,10 @@ titles = {
"Hires steps": "Number of sampling steps for upscaled picture. If 0, uses same as for original.",
"Upscale by": "Adjusts the size of the image by multiplying the original width and height by the selected value. Ignored if either Resize width to or Resize height to are non-zero.",
"Resize width to": "Resizes image to this width. If 0, width is inferred from either of two nearby sliders.",
"Resize height to": "Resizes image to this height. If 0, height is inferred from either of two nearby sliders."
"Resize height to": "Resizes image to this height. If 0, height is inferred from either of two nearby sliders.",
"Multiplier for extra networks": "When adding extra network such as Hypernetwork or Lora to prompt, use this multiplier for it.",
"Discard weights with matching name": "Regular expression; if weights's name matches it, the weights is not written to the resulting checkpoint. Use ^model_ema to discard EMA weights.",
"Extra networks tab order": "Comma-separated list of tab names; tabs listed here will appear in the extra networks UI first and in order lsited."
}


Expand Down
5 changes: 0 additions & 5 deletions javascript/ui.js
Original file line number Diff line number Diff line change
Expand Up @@ -104,11 +104,6 @@ function create_tab_index_args(tabId, args){
return res
}

function get_extras_tab_index(){
const [,,...args] = [...arguments]
return [get_tab_index('mode_extras'), get_tab_index('extras_resize_mode'), ...args]
}

function get_img2img_tab_index() {
let res = args_to_array(arguments)
res.splice(-2)
Expand Down
11 changes: 5 additions & 6 deletions launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,16 +179,14 @@ def run_extensions_installers(settings_file):
def prepare_environment():
global skip_install

torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113")
torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117")
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
commandline_args = os.environ.get('COMMANDLINE_ARGS', "")

gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379")
clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1")
openclip_package = os.environ.get('OPENCLIP_PACKAGE', "git+https://github.com/mlfoundations/open_clip.git@bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b")

xformers_windows_package = os.environ.get('XFORMERS_WINDOWS_PACKAGE', 'https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl')

stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
taming_transformers_repo = os.environ.get('TAMING_TRANSFORMERS_REPO', "https://github.com/CompVis/taming-transformers.git")
k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git')
Expand All @@ -210,6 +208,7 @@ def prepare_environment():
sys.argv, _ = extract_arg(sys.argv, '-f')
sys.argv, skip_torch_cuda_test = extract_arg(sys.argv, '--skip-torch-cuda-test')
sys.argv, reinstall_xformers = extract_arg(sys.argv, '--reinstall-xformers')
sys.argv, reinstall_torch = extract_arg(sys.argv, '--reinstall-torch')
sys.argv, update_check = extract_arg(sys.argv, '--update-check')
sys.argv, run_tests, test_dir = extract_opt(sys.argv, '--tests')
sys.argv, skip_install = extract_arg(sys.argv, '--skip-install')
Expand All @@ -221,7 +220,7 @@ def prepare_environment():
print(f"Python {sys.version}")
print(f"Commit hash: {commit}")

if not is_installed("torch") or not is_installed("torchvision"):
if reinstall_torch or not is_installed("torch") or not is_installed("torchvision"):
run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch")

if not skip_torch_cuda_test:
Expand All @@ -239,14 +238,14 @@ def prepare_environment():
if (not is_installed("xformers") or reinstall_xformers) and xformers:
if platform.system() == "Windows":
if platform.python_version().startswith("3.10"):
run_pip(f"install -U -I --no-deps {xformers_windows_package}", "xformers")
run_pip(f"install -U -I --no-deps xformers==0.0.16rc425", "xformers")
else:
print("Installation of xformers is not supported in this version of Python.")
print("You can also check this and build manually: https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers#building-xformers-on-windows-by-duckness")
if not is_installed("xformers"):
exit(0)
elif platform.system() == "Linux":
run_pip("install xformers", "xformers")
run_pip("install xformers==0.0.16rc425", "xformers")

if not is_installed("pyngrok") and ngrok:
run_pip("install pyngrok", "ngrok")
Expand Down
Loading

0 comments on commit 44c0e6b

Please sign in to comment.