Skip to content

Commit

Permalink
update webui.py
Browse files Browse the repository at this point in the history
  • Loading branch information
xieyongliang committed Apr 7, 2023
1 parent 91afed1 commit 0cdb815
Showing 1 changed file with 92 additions and 1 deletion.
93 changes: 92 additions & 1 deletion webui.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,56 @@
import requests
import json
import uuid

from huggingface_hub import hf_hub_download
import shutil
import glob

if not cmd_opts.api:
from extensions.sd_dreambooth_extension.dreambooth.db_config import DreamboothConfig
from extensions.sd_dreambooth_extension.scripts.dreambooth import start_training_from_config, create_model
from extensions.sd_dreambooth_extension.scripts.dreambooth import performance_wizard, training_wizard
from extensions.sd_dreambooth_extension.dreambooth.db_concept import Concept
from modules import paths
import glob
elif not cmd_opts.pureui
import requests
cache = dict()
s3_client = boto3.client('s3')
s3_resource= boto3.resource('s3')

def s3_download(s3uri, path):
pos = s3uri.find('/', 5)
bucket = s3uri[5 : pos]
key = s3uri[pos + 1 : ]

s3_bucket = s3_resource.Bucket(bucket)
objs = list(s3_bucket.objects.filter(Prefix=key))

if os.path.isfile('cache'):
cache = json.load(open('cache', 'r'))

for obj in objs:
if obj.key == key:
continue
response = s3_client.head_object(
Bucket = bucket,
Key = obj.key
)
obj_key = 's3://{0}/{1}'.format(bucket, obj.key)
if obj_key not in cache or cache[obj_key] != response['ETag']:
filename = obj.key[obj.key.rfind('/') + 1 : ]

s3_client.download_file(bucket, obj.key, os.path.join(path, filename))
cache[obj_key] = response['ETag']

json.dump(cache, open('cache', 'w'))

def http_download(httpuri, path):
with requests.get(httpuri, stream=True) as r:
r.raise_for_status()
with open(path, 'wb') as f:
for chunk in r.iter_content(chunk_size=8192):
f.write(chunk)

if cmd_opts.server_name:
server_name = cmd_opts.server_name
Expand Down Expand Up @@ -194,6 +237,54 @@ def user_auth(username, password):

def webui():
launch_api = cmd_opts.api

if launch_api:
models_config_s3uri = os.environ.get('models_config_s3uri', None)
if models_config_s3uri:
bucket, key = get_bucket_and_key(models_config_s3uri)
s3_object = s3_client.get_object(Bucket=bucket, Key=key)
bytes = s3_object["Body"].read()
payload = bytes.decode('utf8')
huggingface_models = json.loads(payload).get('huggingface_models', None)
s3_models = json.loads(payload).get('s3_models', None)
http_models = json.loads(payload).get('http_models', None)
else:
huggingface_models = os.environ.get('huggingface_models', None)
s3_models = os.environ.get('s3_models', None)
http_models = os.environ.get('http_models', None)

if huggingface_models:
huggingface_models = json.loads(huggingface_models)
huggingface_token = huggingface_models['token']
os.system(f'huggingface-cli login --token {huggingface_token}')
hf_hub_models = huggingface_models['models']
for huggingface_model in hf_hub_models:
repo_id = huggingface_model['repo_id']
filename = huggingface_model['filename']
name = huggingface_model['name']

hf_hub_download(
repo_id=repo_id,
filename=filename,
local_dir=f'/tmp/models/{name}',
cache_dir='/tmp/cache/huggingface'
)

if s3_models:
s3_models = json.loads(s3_models)
for s3_model in s3_models:
uri = s3_model['uri']
name = s3_model['name']
s3_download(uri, f'/tmp/models/{name}')

if http_models:
http_models = json.loads(http_models)
for http_model in http_models:
uri = http_model['uri']
filename = http_model['filename']
name = http_model['name']
http_download(uri, f'/tmp/models/{name}/{filename}')

initialize()

while 1:
Expand Down

0 comments on commit 0cdb815

Please sign in to comment.