Skip to content

Commit

Permalink
retrieve available lora models from backend
Browse files Browse the repository at this point in the history
  • Loading branch information
xieyongliang committed Apr 18, 2023
1 parent 9a322ff commit 256fcbd
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 15 deletions.
45 changes: 32 additions & 13 deletions extensions-builtin/Lora/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@

from modules import shared, devices, sd_models, errors

import requests
import json

metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}

re_digits = re.compile(r"\d+")
Expand Down Expand Up @@ -337,23 +340,39 @@ def lora_MultiheadAttention_load_state_dict(self, *args, **kwargs):
return torch.nn.MultiheadAttention_load_state_dict_before_lora(self, *args, **kwargs)


def list_available_loras():
def list_available_loras(sagemaker_endpoint=None, username=None):
available_loras.clear()

os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)

candidates = \
glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.pt'), recursive=True) + \
glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.safetensors'), recursive=True) + \
glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.ckpt'), recursive=True)

for filename in sorted(candidates, key=str.lower):
if os.path.isdir(filename):
continue
if shared.cmd_opts.pureui:
print(sagemaker_endpoint)
print(username)
if sagemaker_endpoint:
api_endpoint = os.environ['api_endpoint']
params = {'module': 'Lora', 'endpoint_name': sagemaker_endpoint}
response = requests.get(url=f'{api_endpoint}/sd/models', params=params)

if response.status_code == 200:
items = json.loads(response.text)
for item in items:
name = os.path.splitext(item['model_name'])[0]

title = item['title']
available_loras[name] = title
else:
os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)

candidates = \
glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.pt'), recursive=True) + \
glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.safetensors'), recursive=True) + \
glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.ckpt'), recursive=True)

for filename in sorted(candidates, key=str.lower):
if os.path.isdir(filename):
continue

name = os.path.splitext(os.path.basename(filename))[0]
name = os.path.splitext(os.path.basename(filename))[0]

available_loras[name] = LoraOnDisk(name, filename)
available_loras[name] = LoraOnDisk(name, filename)


available_loras = {}
Expand Down
18 changes: 16 additions & 2 deletions modules/ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,7 +569,6 @@ def refresh():

def refresh_sagemaker_endpoints(request : gr.Request):
username = shared.get_webui_username(request)

refresh_method(username)
args = refreshed_args() if callable(refreshed_args) else refreshed_args

Expand All @@ -580,7 +579,6 @@ def refresh_sagemaker_endpoints(request : gr.Request):

def refresh_sd_models(request: gr.Request):
username = shared.get_webui_username(request)

refresh_method(username)
args = refreshed_args() if callable(refreshed_args) else refreshed_args

Expand All @@ -589,6 +587,16 @@ def refresh_sd_models(request: gr.Request):

return gr.update(**(args or {}))

def refresh_lora_models(sagemaker_endpoint,request:gr.Request):
username = shared.get_webui_username(request)
refresh_method(sagemaker_endpoint,username)
args = refreshed_args() if callable(refreshed_args) else refreshed_args

for k, v in args.items():
setattr(refresh_component, k, v)

return gr.update(**(args or {}))

def refresh_checkpoints(sagemaker_endpoint,request:gr.Request):
username = shared.get_webui_username(request)
refresh_method(sagemaker_endpoint,username)
Expand Down Expand Up @@ -618,6 +626,12 @@ def refresh_checkpoints(sagemaker_endpoint,request:gr.Request):
inputs=[shared.sagemaker_endpoint_component],
outputs=[refresh_component]
)
elif elem_id == 'refresh_sd_lora':
refresh_button.click(
fn=refresh_lora_models,
inputs=[shared.sagemaker_endpoint_component],
outputs=[refresh_component]
)
else:
refresh_button.click(
fn=refresh,
Expand Down

0 comments on commit 256fcbd

Please sign in to comment.