Skip to content

Commit

Permalink
pynvml mem usage
Browse files Browse the repository at this point in the history
pynvml mem usage
  • Loading branch information
hlky committed Aug 24, 2022
1 parent 47d79b2 commit 95423e2
Showing 1 changed file with 54 additions and 6 deletions.
60 changes: 54 additions & 6 deletions webui.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
import mimetypes
import random
import math
import pynvml
import threading
import time

import k_diffusion as K
from ldm.util import instantiate_from_config
Expand Down Expand Up @@ -112,6 +115,37 @@ def sample(self, S, conditioning, batch_size, shape, verbose, unconditional_guid

return samples_ddim, None

class MemUsageMonitor(threading.Thread):
stop_flag = False
max_usage = 0
total = 0

def __init__(self, name):
threading.Thread.__init__(self)
self.name = name

def run(self):
print(f"[{self.name}] Recording max memory usage...\n")
pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(0)
self.total = pynvml.nvmlDeviceGetMemoryInfo(handle).total
while not self.stop_flag:
m = pynvml.nvmlDeviceGetMemoryInfo(handle)
self.max_usage = max(self.max_usage, m.used)
# print(self.max_usage)
time.sleep(0.1)
print(f"[{self.name}] Stopped recording.\n")
pynvml.nvmlShutdown()

def read(self):
return self.max_usage, self.total

def stop(self):
self.stop_flag = True

def read_and_stop(self):
self.stop_flag = True
return self.max_usage, self.total

def create_random_tensors(shape, seeds):
xs = []
Expand Down Expand Up @@ -301,9 +335,13 @@ def check_prompt_length(prompt, comments):

def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name, batch_size, n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN):
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""

assert prompt is not None
torch_gc()
# start time after garbage collection (or before?)
start_time = time.time()

mem_mon = MemUsageMonitor('MemMon')
mem_mon.start()

if seed == -1:
seed = random.randrange(4294967294)
Expand Down Expand Up @@ -408,6 +446,14 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name,
grid.save(os.path.join(outpath, grid_file), 'jpeg', quality=80, optimize=True)
grid_count += 1

mem_max_used, mem_total = mem_mon.read_and_stop()
time_diff = time.time()-start_time

notes = f'''
Took { round(time_diff, 2) }s total ({ round(time_diff/(batch_size*n_iter),2) }s per image)<br>
Peak memory usage: { -(mem_max_used // -1_048_576) } MiB / { -(mem_total // -1_048_576) } MiB / { round(mem_max_used/mem_total*100, 3) }%<br>
'''

info = f"""
{prompt}
Steps: {steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}{', GFPGAN' if use_GFPGAN and GFPGAN is not None else ''}
Expand All @@ -416,7 +462,7 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name,
for comment in comments:
info += "\n\n" + comment
torch_gc()
return output_images, seed, info
return output_images, seed, info, notes


def txt2img(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, prompt_matrix: bool, ddim_eta: float, n_iter: int, batch_size: int, cfg_scale: float, seed: int, height: int, width: int):
Expand All @@ -438,7 +484,7 @@ def sample(init_data, x, conditioning, unconditional_conditioning):
samples_ddim, _ = sampler.sample(S=ddim_steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=cfg_scale, unconditional_conditioning=unconditional_conditioning, eta=ddim_eta, x_T=x)
return samples_ddim

output_images, seed, info = process_images(
output_images, seed, info, notes = process_images(
outpath=outpath,
func_init=init,
func_sample=sample,
Expand All @@ -457,7 +503,7 @@ def sample(init_data, x, conditioning, unconditional_conditioning):

del sampler

return output_images, seed, info
return output_images, seed, info, notes


class Flagging(gr.FlaggingCallback):
Expand Down Expand Up @@ -521,6 +567,7 @@ def flag(self, flag_data, flag_option=None, flag_index=None, username=None):
gr.Gallery(label="Images"),
gr.Number(label='Seed'),
gr.Textbox(label="Copy-paste generation parameters"),
gr.HTML(label='Notes'),
],
title="Stable Diffusion Text-to-Image K",
description="Generate images from text with Stable Diffusion (using K-LMS)",
Expand Down Expand Up @@ -562,7 +609,7 @@ def sample(init_data, x, conditioning, unconditional_conditioning):
samples_ddim = K.sampling.sample_lms(model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': cfg_scale}, disable=False)
return samples_ddim

output_images, seed, info = process_images(
output_images, seed, info, notes = process_images(
outpath=outpath,
func_init=init,
func_sample=sample,
Expand All @@ -581,7 +628,7 @@ def sample(init_data, x, conditioning, unconditional_conditioning):

del sampler

return output_images, seed, info
return output_images, seed, info, notes


sample_img2img = "assets/stable-samples/img2img/sketch-mountains-input.jpg"
Expand All @@ -608,6 +655,7 @@ def sample(init_data, x, conditioning, unconditional_conditioning):
gr.Gallery(),
gr.Number(label='Seed'),
gr.Textbox(label="Copy-paste generation parameters"),
gr.HTML(label='Notes'),
],
title="Stable Diffusion Image-to-Image",
description="Generate images from images with Stable Diffusion",
Expand Down

0 comments on commit 95423e2

Please sign in to comment.