Skip to content

Commit

Permalink
Merge pull request AUTOMATIC1111#1 from ShinkoNet/master
Browse files Browse the repository at this point in the history
Memory Patch
  • Loading branch information
hlky authored Aug 24, 2022
2 parents 95423e2 + 60c8177 commit 0ed0dd7
Show file tree
Hide file tree
Showing 2 changed files with 190 additions and 67 deletions.
11 changes: 11 additions & 0 deletions relauncher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import os, time

n = 0
while True:
print('Relauncher: Launching...')
if n > 0:
print(f'\tRelaunch count: {n}')
os.system("python scripts/webui.py")
print('Relauncher: Process ending. Relaunching in 0.5s...')
n += 1
time.sleep(0.5)
246 changes: 179 additions & 67 deletions webui.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,25 @@
import argparse, os, sys, glob
import torch
import torch.nn as nn
import numpy as np
import gradio as gr
from omegaconf import OmegaConf
from PIL import Image, ImageFont, ImageDraw
from itertools import islice
from einops import rearrange, repeat
from torch import autocast
from contextlib import contextmanager, nullcontext
import mimetypes
import random
import k_diffusion as K
import math
import mimetypes
import numpy as np
import pynvml
import random
import threading
import time
import torch
import torch.nn as nn

import k_diffusion as K
from ldm.util import instantiate_from_config
from contextlib import contextmanager, nullcontext
from einops import rearrange, repeat
from itertools import islice
from omegaconf import OmegaConf
from PIL import Image, ImageFont, ImageDraw
from torch import autocast
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.util import instantiate_from_config

try:
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
Expand Down Expand Up @@ -87,6 +87,50 @@ def load_model_from_config(config, ckpt, verbose=False):
model.eval()
return model

def crash(e, s):
global model
global device

print(s, '\n', e)

del model
del device

print('exiting...calling os._exit(0)')
t = threading.Timer(0.25, os._exit, args=[0])
t.start()

class MemUsageMonitor(threading.Thread):
stop_flag = False
max_usage = 0
total = 0

def __init__(self, name):
threading.Thread.__init__(self)
self.name = name

def run(self):
print(f"[{self.name}] Recording max memory usage...\n")
pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(0)
self.total = pynvml.nvmlDeviceGetMemoryInfo(handle).total
while not self.stop_flag:
m = pynvml.nvmlDeviceGetMemoryInfo(handle)
self.max_usage = max(self.max_usage, m.used)
# print(self.max_usage)
time.sleep(0.1)
print(f"[{self.name}] Stopped recording.\n")
pynvml.nvmlShutdown()

def read(self):
return self.max_usage, self.total

def stop(self):
self.stop_flag = True

def read_and_stop(self):
self.stop_flag = True
return self.max_usage, self.total

class CFGDenoiser(nn.Module):
def __init__(self, model):
Expand Down Expand Up @@ -389,8 +433,10 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name,

precision_scope = autocast if opt.precision == "autocast" else nullcontext
output_images = []
stats = []
with torch.no_grad(), precision_scope("cuda"), model.ema_scope():
init_data = func_init()
tic = time.time()

for n in range(n_iter):
prompts = all_prompts[n * batch_size:(n + 1) * batch_size]
Expand Down Expand Up @@ -432,7 +478,6 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name,
grid = image_grid(output_images, batch_size, round_down=prompt_matrix)

if prompt_matrix:

try:
grid = draw_prompt_matrix(grid, width, height, prompt_matrix_parts)
except Exception:
Expand All @@ -442,31 +487,38 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name,

output_images.insert(0, grid)


grid_file = f"grid-{grid_count:05}-{seed}_{prompts[i].replace(' ', '_').translate({ord(x): '' for x in invalid_filename_chars})[:128]}.jpg"
grid.save(os.path.join(outpath, grid_file), 'jpeg', quality=80, optimize=True)
grid_count += 1
toc = time.time()

mem_max_used, mem_total = mem_mon.read_and_stop()
time_diff = time.time()-start_time

notes = f'''
Took { round(time_diff, 2) }s total ({ round(time_diff/(batch_size*n_iter),2) }s per image)<br>
Peak memory usage: { -(mem_max_used // -1_048_576) } MiB / { -(mem_total // -1_048_576) } MiB / { round(mem_max_used/mem_total*100, 3) }%<br>
'''
mem_max_used, mem_total = mem_mon.read_and_stop()
time_diff = time.time()-start_time

info = f"""
{prompt}
Steps: {steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}{', GFPGAN' if use_GFPGAN and GFPGAN is not None else ''}
""".strip()

Steps: {steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}{', GFPGAN' if use_GFPGAN and GFPGAN is not None else ''}{', Prompt Matrix Mode.' if prompt_matrix else ''}""".strip()
stats = f'''
Took { round(time_diff, 2) }s total ({ round(time_diff/(len(all_prompts)),2) }s per image)
Peak memory usage: { -(mem_max_used // -1_048_576) } MiB / { -(mem_total // -1_048_576) } MiB / { round(mem_max_used/mem_total*100, 3) }%'''

for comment in comments:
info += "\n\n" + comment

#mem_mon.stop()
#del mem_mon
torch_gc()
return output_images, seed, info, notes

return output_images, seed, info, stats


def txt2img(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, prompt_matrix: bool, ddim_eta: float, n_iter: int, batch_size: int, cfg_scale: float, seed: int, height: int, width: int):
outpath = opt.outdir or "outputs/txt2img-samples"
err = False

if sampler_name == 'PLMS':
sampler = PLMSSampler(model)
Expand All @@ -483,27 +535,35 @@ def init():
def sample(init_data, x, conditioning, unconditional_conditioning):
samples_ddim, _ = sampler.sample(S=ddim_steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=cfg_scale, unconditional_conditioning=unconditional_conditioning, eta=ddim_eta, x_T=x)
return samples_ddim

output_images, seed, info, notes = process_images(
outpath=outpath,
func_init=init,
func_sample=sample,
prompt=prompt,
seed=seed,
sampler_name=sampler_name,
batch_size=batch_size,
n_iter=n_iter,
steps=ddim_steps,
cfg_scale=cfg_scale,
width=width,
height=height,
prompt_matrix=prompt_matrix,
use_GFPGAN=use_GFPGAN
)

del sampler

return output_images, seed, info, notes
try:
output_images, seed, info, stats = process_images(
outpath=outpath,
func_init=init,
func_sample=sample,
prompt=prompt,
seed=seed,
sampler_name=sampler_name,
batch_size=batch_size,
n_iter=n_iter,
steps=ddim_steps,
cfg_scale=cfg_scale,
width=width,
height=height,
prompt_matrix=prompt_matrix,
use_GFPGAN=use_GFPGAN
)

del sampler

return output_images, seed, info, stats
except RuntimeError as e:
err = e
err_msg = f'CRASHED:<br><textarea rows="5" style="background: black;width: -webkit-fill-available;font-family: monospace;font-size: small;font-weight: bold;">{str(e)}</textarea><br><br>Please wait while the program restarts.'
stats = err_msg
return [], 1
finally:
if err:
crash(err, '!!Runtime error (txt2img)!!')


class Flagging(gr.FlaggingCallback):
Expand Down Expand Up @@ -567,16 +627,17 @@ def flag(self, flag_data, flag_option=None, flag_index=None, username=None):
gr.Gallery(label="Images"),
gr.Number(label='Seed'),
gr.Textbox(label="Copy-paste generation parameters"),
gr.HTML(label='Notes'),
gr.HTML(label='Stats'),
],
title="Stable Diffusion Text-to-Image K",
description="Generate images from text with Stable Diffusion (using K-LMS)",
title="Stable Diffusion Text-to-Image Unified",
description="Generate images from text with Stable Diffusion",
flagging_callback=Flagging()
)


def img2img(prompt: str, init_img, ddim_steps: int, use_GFPGAN: bool, prompt_matrix, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int):
outpath = opt.outdir or "outputs/img2img-samples"
err = False

sampler = KDiffusionSampler(model)

Expand Down Expand Up @@ -609,26 +670,77 @@ def sample(init_data, x, conditioning, unconditional_conditioning):
samples_ddim = K.sampling.sample_lms(model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': cfg_scale}, disable=False)
return samples_ddim

output_images, seed, info, notes = process_images(
outpath=outpath,
func_init=init,
func_sample=sample,
prompt=prompt,
seed=seed,
sampler_name='k-diffusion',
batch_size=batch_size,
n_iter=n_iter,
steps=ddim_steps,
cfg_scale=cfg_scale,
width=width,
height=height,
prompt_matrix=prompt_matrix,
use_GFPGAN=use_GFPGAN
)

del sampler
try:
if loopback:
output_images, info = None, None
history = []
initial_seed = None

for i in range(n_iter):
output_images, seed, info, stats = process_images(
outpath=outpath,
func_init=init,
func_sample=sample,
prompt=prompt,
seed=seed,
sampler_name='k-diffusion',
batch_size=1,
n_iter=1,
steps=ddim_steps,
cfg_scale=cfg_scale,
width=width,
height=height,
prompt_matrix=prompt_matrix,
use_GFPGAN=use_GFPGAN,
do_not_save_grid=True
)

if initial_seed is None:
initial_seed = seed

init_img = output_images[0]
seed = seed + 1
denoising_strength = max(denoising_strength * 0.95, 0.1)
history.append(init_img)

grid_count = len(os.listdir(outpath)) - 1
grid = image_grid(history, batch_size, force_n_rows=1)
grid.save(os.path.join(outpath, f'grid-{grid_count:04}.{opt.grid_format}'))

output_images = history
seed = initial_seed

else:
output_images, seed, info, stats = process_images(
outpath=outpath,
func_init=init,
func_sample=sample,
prompt=prompt,
seed=seed,
sampler_name='k-diffusion',
batch_size=batch_size,
n_iter=n_iter,
steps=ddim_steps,
cfg_scale=cfg_scale,
width=width,
height=height,
prompt_matrix=prompt_matrix,
use_GFPGAN=use_GFPGAN
)

del sampler

return output_images, seed, info, stats
except RuntimeError as e:
err = e
err_msg = f'CRASHED:<br><textarea rows="5" style="background: black;width: -webkit-fill-available;font-family: monospace;font-size: small;font-weight: bold;">{str(e)}</textarea><br><br>Please wait while the program restarts.'
stats = err_msg
return [], 1
finally:
if err:
crash(err, '!!Runtime error (img2img)!!')

return output_images, seed, info, notes


sample_img2img = "assets/stable-samples/img2img/sketch-mountains-input.jpg"
Expand All @@ -655,9 +767,9 @@ def sample(init_data, x, conditioning, unconditional_conditioning):
gr.Gallery(),
gr.Number(label='Seed'),
gr.Textbox(label="Copy-paste generation parameters"),
gr.HTML(label='Notes'),
gr.HTML(label='Stats'),
],
title="Stable Diffusion Image-to-Image",
title="Stable Diffusion Image-to-Image Unified",
description="Generate images from images with Stable Diffusion",
allow_flagging="never",
)
Expand Down Expand Up @@ -700,4 +812,4 @@ def run_GFPGAN(image, strength):
css=("" if opt.no_progressbar_hiding else css_hide_progressbar)
)

demo.launch()
demo.launch()

0 comments on commit 0ed0dd7

Please sign in to comment.