Skip to content

Commit

Permalink
Add error response and handle options and exceptions
Browse files Browse the repository at this point in the history
  • Loading branch information
xieyongliang committed Apr 23, 2023
1 parent d71aeac commit a39b362
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 43 deletions.
91 changes: 48 additions & 43 deletions modules/api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def __init__(self, app: FastAPI, queue_lock: Lock):
self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"])
self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"])
self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=ScriptsList)
self.add_api_route("/invocations", self.invocations, methods=["POST"], response_model=Union[TextToImageResponse, ImageToImageResponse, ExtrasSingleImageResponse, ExtrasBatchImagesResponse])
self.add_api_route("/invocations", self.invocations, methods=["POST"], response_model=Union[TextToImageResponse, ImageToImageResponse, ExtrasSingleImageResponse, ExtrasBatchImagesResponse, InvocationsErrorResponse])
self.add_api_route("/ping", self.ping, methods=["GET"], response_model=PingResponse)

self.default_script_arg_txt2img = []
Expand Down Expand Up @@ -739,51 +739,56 @@ def invocations(self, req: InvocationsRequest):
print('-------invocation------')
print(req)

if req.vae != None:
shared.opts.data['sd_vae'] = req.vae
refresh_vae_list()

if req.model != None:
sd_model_checkpoint = shared.opts.sd_model_checkpoint
shared.opts.sd_model_checkpoint = req.model
with self.queue_lock:
reload_model_weights()
if sd_model_checkpoint == shared.opts.sd_model_checkpoint:
reload_vae_weights()

quality = req.quality

embeddings_s3uri = shared.cmd_opts.embeddings_s3uri
hypernetwork_s3uri = shared.cmd_opts.hypernetwork_s3uri

self.download_s3files(hypernetwork_s3uri, os.path.join(script_path, shared.cmd_opts.hypernetwork_dir))
shared.reload_hypernetworks()

try:
if req.task == 'text-to-image':
self.download_s3files(embeddings_s3uri, os.path.join(script_path, shared.cmd_opts.embeddings_dir))
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
response = self.text2imgapi(req.txt2img_payload)
response.images = self.post_invocations(response.images, quality)
return response
elif req.task == 'image-to-image':
self.download_s3files(embeddings_s3uri, os.path.join(script_path, shared.cmd_opts.embeddings_dir))
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
response = self.img2imgapi(req.img2img_payload)
response.images = self.post_invocations(response.images, quality)
return response
elif req.task == 'extras-single-image':
response = self.extras_single_image_api(req.extras_single_payload)
response.image = self.post_invocations([response.image], quality)[0]
return response
elif req.task == 'extras-batch-images':
response = self.extras_batch_images_api(req.extras_batch_payload)
response.images = self.post_invocations(response.images, quality)
return response
else:
raise NotImplementedError
if req.vae != None:
shared.opts.data['sd_vae'] = req.vae
refresh_vae_list()

if req.model != None:
sd_model_checkpoint = shared.opts.sd_model_checkpoint
shared.opts.sd_model_checkpoint = req.model
with self.queue_lock:
reload_model_weights()
if sd_model_checkpoint == shared.opts.sd_model_checkpoint:
reload_vae_weights()

quality = req.quality

embeddings_s3uri = shared.cmd_opts.embeddings_s3uri
hypernetwork_s3uri = shared.cmd_opts.hypernetwork_s3uri

self.download_s3files(hypernetwork_s3uri, os.path.join(script_path, shared.cmd_opts.hypernetwork_dir))
shared.reload_hypernetworks()

if req.options != None:
options = json.lods(req.options)
for key in options:
shared.opts.data[key] = options[key]
if req.task == 'text-to-image':
self.download_s3files(embeddings_s3uri, os.path.join(script_path, shared.cmd_opts.embeddings_dir))
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
response = self.text2imgapi(req.txt2img_payload)
response.images = self.post_invocations(response.images, quality)
return response
elif req.task == 'image-to-image':
self.download_s3files(embeddings_s3uri, os.path.join(script_path, shared.cmd_opts.embeddings_dir))
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
response = self.img2imgapi(req.img2img_payload)
response.images = self.post_invocations(response.images, quality)
return response
elif req.task == 'extras-single-image':
response = self.extras_single_image_api(req.extras_single_payload)
response.image = self.post_invocations([response.image], quality)[0]
return response
elif req.task == 'extras-batch-images':
response = self.extras_batch_images_api(req.extras_batch_payload)
response.images = self.post_invocations(response.images, quality)
return response
else:
return InvocationsErrorResponse(error = f'Invalid task - {req.task}')
except Exception as e:
traceback.print_exc()
return InvocationsErrorResponse(error = str(e))

def ping(self):
print('-------ping------')
Expand Down
4 changes: 4 additions & 0 deletions modules/api/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,10 +296,14 @@ class InvocationsRequest(BaseModel):
model: Optional[str]
vae: Optional[str]
quality: Optional[int]
options: Optional[str]
txt2img_payload: Optional[StableDiffusionTxt2ImgProcessingAPI]
img2img_payload: Optional[StableDiffusionImg2ImgProcessingAPI]
extras_single_payload: Optional[ExtrasSingleImageRequest]
extras_batch_payload: Optional[ExtrasBatchImagesRequest]

class InvocationsErrorResponse(BaseModel):
error: str = Field(title="Invocation error", description="Error response from invocation.")

class PingResponse(BaseModel):
status: str

0 comments on commit a39b362

Please sign in to comment.