Skip to content

Commit

Permalink
Merge pull request AUTOMATIC1111#7 from xiehust/dev
Browse files Browse the repository at this point in the history
River's dev branch merge
  • Loading branch information
xieyongliang authored Apr 12, 2023
2 parents f352ab2 + 78afa45 commit 06824a5
Show file tree
Hide file tree
Showing 13 changed files with 751 additions and 111 deletions.
14 changes: 13 additions & 1 deletion localizations/zh_CN.json
Original file line number Diff line number Diff line change
Expand Up @@ -839,6 +839,18 @@
"Amount of time to pause between Epochs (s)": "Epochs 间隔等待时间",
"Save Preview(s) Frequency (Epochs)": "保存预览频率 (Epochs)",
"A generic prompt used to generate a sample image to verify model fidelity.": "用于生成样本图像以验证模型保真度的通用提示。",

"Job detail":"训练任务详情",
"S3 bucket name for uploading/downloading images":"上传训练图片集或者下载生成图片的S3桶名",
"Output S3 folder":"S3文件夹目录",
"Upload Train Images to S3":"上传训练图片到S3",
"Error, please configure a S3 bucket at settings page first":"失败,请先到设置页面配置S3桶名",
"Upload Images":"上传图片",
"Reload all models":"重新加载模型文件",
"Update model files path":"更新模型加载路径",
"S3 path for downloading model files (E.g, s3://bucket-name/models/)":"加载模型的S3路径,例如:s3://bucket-name/models/",
"Images Viewer":"图片浏览器",
"Input S3 path of images":"输入图片的S3路径",
"Submit":"确定",
"columns width":"每行图片列数",
"--------": "--------"
}
57 changes: 54 additions & 3 deletions modules/api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from fastapi import APIRouter, Depends, FastAPI, HTTPException
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from secrets import compare_digest

from modules.shared import de_register_model
import modules.shared as shared
from modules import sd_samplers, deepbooru
from modules.api.models import *
Expand Down Expand Up @@ -427,6 +427,7 @@ def invocations(self, req: InvocationsRequest):
try:
username = req.username
default_options = shared.opts.data

if username != '':
inputs = {
'action': 'get',
Expand All @@ -446,7 +447,10 @@ def invocations(self, req: InvocationsRequest):
self.download_s3files(hypernetwork_s3uri, os.path.join(script_path, shared.cmd_opts.hypernetwork_dir))
hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork)
hypernetworks.hypernetwork.apply_strength()

##add sd model usage stats by River
print(f'default_options:{shared.opts.data}')
shared.sd_models_Ref.add_models_ref(shared.opts.data['sd_model_checkpoint'])
##end
if req.task == 'text-to-image':
self.download_s3files(embeddings_s3uri, os.path.join(script_path, shared.cmd_opts.embeddings_dir))
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
Expand All @@ -471,15 +475,62 @@ def invocations(self, req: InvocationsRequest):
self.post_invocations(username, response.images)
shared.opts.data = default_options
return response
elif req.task == 'reload-all-models':
return self.reload_all_models()
elif req.task == 'set-models-bucket':
bucket = req.models_bucket
return self.set_models_bucket(bucket)
else:
raise NotImplementedError
except Exception as e:
traceback.print_exc()

def ping(self):
print('-------ping------')
# print('-------ping------')
return {'status': 'Healthy'}

def reload_all_models(self):
print('-------reload_all_models------')
def remove_files(path):
for file_name in os.listdir(path):
file_path = os.path.join(path, file_name)
if os.path.isfile(file_path):
os.remove(file_path)
print(f'{file_path} has been removed')
if file_path.find('Stable-diffusion'):
de_register_model(file_name,'sd')
elif file_path.find('ControlNet'):
de_register_model(file_name,'cn')
elif os.path.isdir(file_path):
remove_files(file_path)
os.rmdir(file_path)
shared.syncLock.acquire()
#remove all files in /tmp/models/ and /tmp/cache/
remove_files(shared.tmp_models_dir)
remove_files(shared.tmp_cache_dir)
shared.syncLock.release()
return {'simple_result':'success'}

def set_models_bucket(self,bucket):
shared.syncLock.acquire()
if bucket.endswith('/'):
bucket = bucket[:-1]
url_parts = bucket.replace('s3://','').split('/')
shared.models_s3_bucket = url_parts[0]
lastfolder = url_parts[-1]
if lastfolder == 'Stable-diffusion':
shared.s3_folder_sd = '/'.join(url_parts[1:])
elif lastfolder == 'ControlNet':
shared.s3_folder_cn = '/'.join(url_parts[1:])
else:
shared.s3_folder_sd = '/'.join(url_parts[1:]+['Stable-diffusion'])
shared.s3_folder_cn = '/'.join(url_parts[1:]+['ControlNet'])
print(f'set_models_bucket to {shared.models_s3_bucket}')
print(f'set_s3_folder_sd to {shared.s3_folder_sd}')
print(f'set_s3_folder_cn to {shared.s3_folder_cn}')
shared.syncLock.release()
return {'simple_result':'success'}

def launch(self, server_name, port):
self.app.include_router(self.router)
uvicorn.run(self.app, host=server_name, port=port)
Expand Down
5 changes: 3 additions & 2 deletions modules/api/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,11 +244,12 @@ class ArtistItem(BaseModel):
class InvocationsRequest(BaseModel):
task: str
username: Optional[str]
models_bucket:Optional[str]
simple_result:Optional[str]
txt2img_payload: Optional[StableDiffusionTxt2ImgProcessingAPI]
img2img_payload: Optional[StableDiffusionImg2ImgProcessingAPI]
extras_single_payload: Optional[ExtrasSingleImageRequest]
extras_batch_payload: Optional[ExtrasBatchImagesRequest]

class PingResponse(BaseModel):
status: str

status: str
4 changes: 2 additions & 2 deletions modules/call_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,7 @@ def f(request: gr.Request, *args, extra_outputs_array=extra_outputs, **kwargs):
t = time.perf_counter()

try:
if func.__name__ == 'f' or func.__name__ == 'run_settings':
if func.__name__ == 'f' or func.__name__ == 'run_settings' or func.__name__ == 'save_files':
res = list(func(username, *args, **kwargs))
else:
res = list(func(*args, **kwargs))
Expand Down Expand Up @@ -505,4 +505,4 @@ def f(request: gr.Request, *args, extra_outputs_array=extra_outputs, **kwargs):

return tuple(res)

return f
return f
6 changes: 4 additions & 2 deletions modules/paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
import os
import sys
import modules.safe

script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
models_path = os.path.join(script_path, "models")
## Change by River
# models_path = os.path.join(script_path, "models")
models_path = '/tmp/models'
##
sys.path.insert(0, script_path)

# search for directory of stable diffusion in following places
Expand Down
14 changes: 14 additions & 0 deletions modules/script_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def __init__(self, imgs, cols, rows):
callbacks_infotext_pasted=[],
callbacks_script_unloaded=[],
callbacks_before_ui=[],
callbacks_update_cn_models=[]
)


Expand Down Expand Up @@ -192,6 +193,14 @@ def before_ui_callback():
except Exception:
report_exception(c, 'before_ui')

##Add by River
def update_cn_models_callback():
for c in callback_map['callbacks_update_cn_models']:
try:
c.callback()
except Exception:
report_exception(c, 'callbacks_update_cn_models')
##End by River
def add_callback(callbacks, fun):
stack = [x for x in inspect.stack() if x.filename != __file__]
filename = stack[0].filename if len(stack) > 0 else 'unknown file'
Expand All @@ -214,6 +223,10 @@ def remove_callbacks_for_function(callback_func):
for callback_to_remove in [cb for cb in callback_list if cb.callback == callback_func]:
callback_list.remove(callback_to_remove)

##Add by River
def on_update_cn_models(callback):
add_callback(callback_map['callbacks_update_cn_models'], callback)
##End by River

def on_app_started(callback):
"""register a function to be called when the webui started, the gradio `Block` component and
Expand Down Expand Up @@ -320,3 +333,4 @@ def on_before_ui(callback):
"""register a function to be called before the UI is created."""

add_callback(callback_map['callbacks_before_ui'], callback)

3 changes: 1 addition & 2 deletions modules/scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,6 @@ def load_scripts():
script_callbacks.clear_callbacks()

scripts_list = list_scripts("scripts", ".py")

syspath = sys.path

for scriptfile in sorted(scripts_list):
Expand All @@ -203,6 +202,7 @@ def load_scripts():
finally:
sys.path = syspath
current_basedir = paths.script_path
print('scripts_data',scripts_data)


def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
Expand Down Expand Up @@ -316,7 +316,6 @@ def run(self, p: StableDiffusionProcessing, *args):

if script_index == 0:
return None

script = self.selectable_scripts[script_index-1]

if script is None:
Expand Down
6 changes: 5 additions & 1 deletion modules/sd_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(models_path, model_dir))


CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name', 'config'])
checkpoints_list = {}
checkpoints_loaded = collections.OrderedDict()
Expand Down Expand Up @@ -81,7 +82,8 @@ def modeltitle(path, shorthash):
if shared.cmd_opts.pureui:
if sagemaker_endpoint:
params = {
'module': 'Stable-diffusion', 'endpoint_name': sagemaker_endpoint
'module': 'Stable-diffusion',
'endpoint_name': sagemaker_endpoint
}
response = requests.get(url=f'{api_endpoint}/sd/models', params=params)
if response.status_code == 200:
Expand Down Expand Up @@ -161,6 +163,8 @@ def model_hash(filename):


def select_checkpoint():
##add log by Rive
print('checkpoints_list:',checkpoints_list)
model_checkpoint = shared.opts.sd_model_checkpoint
checkpoint_info = checkpoints_list.get(model_checkpoint, None)
if checkpoint_info is not None:
Expand Down
Loading

0 comments on commit 06824a5

Please sign in to comment.