Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CodeCamp #81] Gradio GUI supporting inpainting inference #1601

Merged
merged 39 commits into from
Jan 30, 2023
Merged
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
c35fac9
第一次提交
xiaomile Nov 27, 2022
a9c2f60
第一次提交
xiaomile Nov 28, 2022
8669775
Merge branch 'open-mmlab:dev-1.x' into dev-1.x
xiaomile Nov 28, 2022
4fdd56d
第二次提交
xiaomile Nov 28, 2022
95b2ea9
Merge branch 'dev-1.x' of https://github.com/xiaomile/mmediting into …
xiaomile Nov 28, 2022
a41ee73
第三次提交
xiaomile Nov 28, 2022
6463195
第四次提交,修改 isort
xiaomile Nov 28, 2022
1bbdf29
第5次提交,isort调整
xiaomile Nov 28, 2022
79946e4
第五次提交,调整isort
xiaomile Nov 28, 2022
d87e4c5
第6次提交,调整yapf
xiaomile Nov 28, 2022
52e450a
Merge branch 'dev-1.x' into dev-1.x
xiaomile Nov 30, 2022
3e20021
第7次提交,针对部分类型修改
xiaomile Dec 2, 2022
bca3da1
第7次提交,针对部分类型修改
xiaomile Dec 2, 2022
fa1ab45
第7次提交,针对部分类型修改
xiaomile Dec 2, 2022
f4d9df2
第7次提交,针对部分类型修改
xiaomile Dec 2, 2022
68af8e7
Merge branch 'dev-1.x' into dev-1.x
xiaomile Dec 2, 2022
b0287b7
第八次提交,根据要求修改部分参数类型和函数返回类型
xiaomile Dec 5, 2022
91773a6
第八次提交,yapf调整
xiaomile Dec 5, 2022
e367464
Merge branch 'dev-1.x' into dev-1.x
xiaomile Dec 5, 2022
de6e2f7
第九次提交,img_normalize.py部分类型修改
xiaomile Dec 8, 2022
33ebc55
Merge branch 'dev-1.x' of https://github.com/xiaomile/mmediting into …
xiaomile Dec 8, 2022
39485bc
Merge branch 'dev-1.x' into dev-1.x
xiaomile Dec 8, 2022
f6f4339
Merge branch 'dev-1.x' into dev-1.x
xiaomile Dec 8, 2022
2431137
第十次提交,修改base_edit_model.py和base_mattor.py部分参数类型
xiaomile Dec 9, 2022
4bcbc86
Merge branch 'dev-1.x' of https://github.com/xiaomile/mmediting into …
xiaomile Dec 9, 2022
650ec60
Merge branch 'dev-1.x' into dev-1.x
xiaomile Dec 9, 2022
4547782
gradio实现inpaiting gui
xiaomile Jan 13, 2023
42d806a
Merge branch 'open-mmlab:dev-1.x' into dev-1.x
xiaomile Jan 13, 2023
13924f1
yapf isort 调整
xiaomile Jan 13, 2023
146c7c1
Merge branch 'dev-1.x' of https://github.com/xiaomile/mmediting into …
xiaomile Jan 13, 2023
62e3592
Merge branch 'dev-1.x' into dev-1.x
xiaomile Jan 25, 2023
9ca0c8d
send_notification函数和setting输入项调整
xiaomile Jan 29, 2023
c71e3a7
Merge branch 'dev-1.x' into dev-1.x
LeoXing1996 Jan 29, 2023
c9c00bc
lint 调整
xiaomile Jan 30, 2023
9ff136e
Merge branch 'dev-1.x' of https://github.com/xiaomile/mmediting into …
xiaomile Jan 30, 2023
47c7e41
double-quote-string 调整
xiaomile Jan 30, 2023
e643234
double-quote-string调整
xiaomile Jan 30, 2023
bb0409f
docformatter 调整
xiaomile Jan 30, 2023
499fa87
Merge branch 'dev-1.x' into dev-1.x
xiaomile Jan 30, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
325 changes: 325 additions & 0 deletions demo/gradio-demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,325 @@
# Copyright (c) OpenMMLab. All rights reserved.
import json
import os
import os.path as osp
import subprocess
import traceback
import warnings
from typing import Dict, List, Optional, Union

import cv2
import gradio as gr
import numpy as np
import torch
import yaml

from mmedit.apis.inferencers.inpainting_inferencer import InpaintingInferencer
from mmedit.utils import register_all_modules


class InpaintingGradio:
    """Gradio-based GUI for running MMEditing inpainting models.

    Presents a browser UI in which the user selects one of the supported
    inpainting models (or supplies an explicit config/checkpoint pair),
    draws a mask over an uploaded image and runs inference on it.
    """

    inpainting_supported_models = [
        # inpainting models
        'aot_gan',
        'deepfillv1',
        'deepfillv2',
        'global_local',
        'partial_conv',
    ]
    # Lazily populated map: model name -> {'settings': [...]} parsed from
    # each model's metafile.yml (see init_inference_supported_models_cfg).
    inpainting_supported_models_cfg = {}
    inpainting_supported_models_cfg_inited = False
    # Installation directory of the mmedit package, resolved on first init.
    pkg_path = ''
    error_color = '#FF0000'
    success_color = '#00FF00'
    warning_color = '#FFFF00'
    # Latest notification as (message, color, label); rendered by a gr.Label.
    notice_message = ('', None, '')

    def __init__(self,
                 model_name: str = None,
                 model_setting: int = None,
                 model_config: str = None,
                 model_ckpt: str = None,
                 device: torch.device = None,
                 extra_parameters: Dict = None,
                 seed: int = 2022,
                 **kwargs) -> None:
        """Store the GUI state and, when a model is already resolvable,
        eagerly build the underlying ``InpaintingInferencer``.

        Args:
            model_name (str, optional): Name of a supported model.
            model_setting (int, optional): Index of the setting to use from
                the model's metafile.
            model_config (str, optional): Explicit path to a model config.
            model_ckpt (str, optional): Explicit path to a checkpoint.
            device (torch.device, optional): Inference device.
            extra_parameters (Dict, optional): Extra kwargs forwarded to the
                inferencer.
            seed (int): Random seed passed to the inferencer.
        """
        register_all_modules(init_default_scope=True)
        InpaintingGradio.init_inference_supported_models_cfg()
        self.model_name = model_name
        self.model_setting = model_setting
        self.model_config = model_config
        self.model_ckpt = model_ckpt
        self.device = device
        self.extra_parameters = extra_parameters
        self.seed = seed
        # Only build the inferencer when a model can actually be resolved:
        # either by name or by an explicit config/checkpoint pair.
        if model_name or (model_config and model_ckpt):
            inpainting_kwargs = {}
            inpainting_kwargs.update(
                self._get_inpainting_kwargs(model_name, model_setting,
                                            model_config, model_ckpt,
                                            extra_parameters))
            self.inference = InpaintingInferencer(
                device=device, seed=seed, **inpainting_kwargs)

    def model_reconfig(self,
                       model_name: str = None,
                       model_setting: int = None,
                       model_config: str = None,
                       model_ckpt: str = None,
                       device: torch.device = None,
                       extra_parameters: Dict = None,
                       seed: int = 2022,
                       **kwargs) -> None:
        """Rebuild the inferencer from the GUI's current settings.

        Reports invalid argument combinations and construction failures via
        :meth:`send_notification` instead of raising to the caller, so the
        GUI keeps running.
        """
        inpainting_kwargs = {}
        # Validate argument combinations: a setting index needs a model
        # name, and an explicit config/ckpt needs its counterpart or a name.
        if not model_name and model_setting:
            self.send_notification(
                'model_name should not be None when model_setting was used',
                self.error_color, 'error')
            return
        elif (not model_config and not model_name) and model_ckpt:
            self.send_notification(
                'model_name and model_config should not be None when '
                'model_ckpt was used', self.error_color, 'error')
            return
        elif (not model_ckpt and not model_name) and model_config:
            self.send_notification(
                'model_name and model_ckpt should not be None when '
                'model_config was used', self.error_color, 'error')
            return
        inpainting_kwargs.update(
            self._get_inpainting_kwargs(model_name, model_setting,
                                        model_config, model_ckpt,
                                        extra_parameters))
        try:
            self.inference = InpaintingInferencer(
                device=device, seed=seed, **inpainting_kwargs)
        except Exception as e:
            self.send_notification('inference Exception:' + str(e),
                                   self.error_color, 'error')
            traceback.print_exc()
            return
        self.send_notification('Model config Finished!', self.success_color,
                               'success')

    @staticmethod
    def change_text2dict(input_text: str) -> Union[Dict, None]:
        """Parse a JSON string into a Python object.

        Returns:
            Union[Dict, None]: The parsed value, or None (after emitting an
            error notification) when ``input_text`` is not valid JSON.
        """
        return_dict = None
        try:
            return_dict = json.loads(input_text)
        except Exception as e:
            InpaintingGradio.send_notification(
                'Convert string to dict Exception:' + str(e),
                InpaintingGradio.error_color, 'error')
        return return_dict

    @staticmethod
    def get_package_path() -> str:
        """Locate the installed ``mmedit`` package via ``pip show``.

        Returns:
            str: Installation directory of mmedit, ending with ``os.sep``.

        Raises:
            Exception: If mmedit is not installed.
        """
        p = subprocess.Popen(
            'pip show mmedit', shell=True, stdout=subprocess.PIPE)
        out, err = p.communicate()
        out = out.decode()
        if 'Location' not in out:
            InpaintingGradio.send_notification('module mmedit not found',
                                               InpaintingGradio.error_color,
                                               'error')
            raise Exception('module mmedit not found')
        # Use splitlines() so parsing works with both '\n' (POSIX) and
        # '\r\n' (Windows) line endings; the previous split('\r\n') returned
        # the entire remaining multi-line output on POSIX systems.
        package_path = out[out.find('Location') +
                           len('Location: '):].splitlines()[0] + os.sep
        return package_path

    def get_model_config(self, model_name: str) -> Dict:
        """Get the model configuration including model config and checkpoint
        url.

        Args:
            model_name (str): Name of the model.
        Returns:
            dict: Model configuration.
        Raises:
            ValueError: If ``model_name`` is not a supported model.
        """
        if model_name not in self.inpainting_supported_models:
            self.send_notification(f'Model {model_name} is not supported.',
                                   self.error_color, 'error')
            raise ValueError(f'Model {model_name} is not supported.')
        else:
            return self.inpainting_supported_models_cfg[model_name]

    @staticmethod
    def init_inference_supported_models_cfg() -> None:
        """Parse each supported model's metafile.yml once and cache the
        available settings in ``inpainting_supported_models_cfg``."""
        if not InpaintingGradio.inpainting_supported_models_cfg_inited:
            InpaintingGradio.pkg_path = InpaintingGradio.get_package_path()
            all_cfgs_dir = osp.join(InpaintingGradio.pkg_path, 'configs')
            for model_name in InpaintingGradio.inpainting_supported_models:
                meta_file_dir = osp.join(all_cfgs_dir, model_name,
                                         'metafile.yml')
                with open(meta_file_dir, 'r') as stream:
                    parsed_yaml = yaml.safe_load(stream)
                InpaintingGradio.inpainting_supported_models_cfg[
                    model_name] = {}
                InpaintingGradio.inpainting_supported_models_cfg[model_name][
                    'settings'] = parsed_yaml['Models']  # noqa
            InpaintingGradio.inpainting_supported_models_cfg_inited = True

    def _get_inpainting_kwargs(self, model_name: Optional[str],
                               model_setting: Optional[int],
                               model_config: Optional[str],
                               model_ckpt: Optional[str],
                               extra_parameters: Optional[Dict]) -> Dict:
        """Get the kwargs for the inpainting inferencer.

        Resolves ``model_name``/``model_setting`` into config/checkpoint
        paths from the cached metafiles; explicit ``model_config`` and
        ``model_ckpt`` override the resolved defaults (with a warning).
        """
        kwargs = {}

        if model_name:
            cfgs = self.get_model_config(model_name)
            setting_to_use = 0
            if model_setting:
                # The settings dropdown supplies strings of the form
                # '<index> - <config path>'; extract the leading index.
                if isinstance(model_setting, str):
                    model_setting = int(
                        model_setting[0:model_setting.find(' - ')])
                setting_to_use = model_setting
            else:
                model_setting = 0
            if model_setting > len(
                    cfgs['settings']) - 1 or model_setting < -len(
                        cfgs['settings']):
                self.send_notification(
                    f"model_setting out of range of {model_name}'s "
                    'cfgs settings', self.error_color, 'error')
                # Fail fast with a descriptive error; previously execution
                # fell through and raised a bare IndexError just below.
                raise ValueError(
                    f"model_setting out of range of {model_name}'s "
                    'cfgs settings')
            config_dir = cfgs['settings'][setting_to_use]['Config']
            # Keep only the part relative to the installed package root.
            config_dir = config_dir[config_dir.find('configs'):]
            kwargs['config'] = os.path.join(self.pkg_path, config_dir)
            kwargs['ckpt'] = cfgs['settings'][setting_to_use]['Weights']

        if model_config:
            if kwargs.get('config', None) is not None:
                warnings.warn(
                    f'{model_name}\'s default config '
                    f'is overridden by {model_config}', UserWarning)
            kwargs['config'] = model_config

        if model_ckpt:
            if kwargs.get('ckpt', None) is not None:
                warnings.warn(
                    f'{model_name}\'s default checkpoint '
                    f'is overridden by {model_ckpt}', UserWarning)
            kwargs['ckpt'] = model_ckpt

        if extra_parameters:
            kwargs['extra_parameters'] = extra_parameters

        return kwargs

    @staticmethod
    def send_notification(msg: str, color: str, label: str) -> None:
        """Record a notification to be shown in the GUI's label widget."""
        InpaintingGradio.notice_message = (msg, color, label)

    @staticmethod
    def get_inpainting_supported_models() -> List:
        """static function for getting inpainting inference supported modes."""
        return InpaintingGradio.inpainting_supported_models

    def infer(self, input_img_arg: Dict) -> np.ndarray:
        """Run inpainting on the sketched image and return a BGR image.

        Args:
            input_img_arg (Dict): Gradio sketch-tool payload with 'image'
                and 'mask' entries.
        """
        result = self.inference(
            img=input_img_arg['image'], mask=input_img_arg['mask'])
        # NOTE(review): result[1] is assumed to be the inpainted RGB image
        # returned by InpaintingInferencer — confirm against its API.
        result = cv2.cvtColor(result[1], cv2.COLOR_RGB2BGR)
        return result

    def run(self) -> None:
        """Build the Gradio Blocks layout and launch the web demo."""
        with gr.Blocks() as demo:
            with gr.Row():
                with gr.Column():
                    # Model selection / configuration inputs.
                    input_model_dropdown = gr.Dropdown(
                        choices=self.inpainting_supported_models,
                        value=self.model_name,
                        label='choose model')
                    input_setting = gr.Dropdown(
                        value=self.model_setting, label='choose setting')
                    input_config = gr.Textbox(
                        value=self.model_config, label='model_config_path')
                    input_ckpt = gr.Textbox(
                        value=self.model_ckpt, label='model_ckpt_path')
                    input_device_dropdown = gr.Dropdown(
                        choices=['cuda', 'cpu'],
                        label='choose device',
                        value='cuda')
                    input_extra_parameters_input = gr.Textbox(
                        value=json.dumps(self.extra_parameters),
                        label='extra_parameters')
                    input_extra_parameters = gr.JSON(
                        value=self.extra_parameters, label='extra_parameters')
                    input_seed = gr.Number(
                        value=self.seed, precision=0, label='seed')
                    config_button = gr.Button('CONFIG')
                    # Mirror the free-text parameters into the JSON widget
                    # whenever the textbox loses focus.
                    input_extra_parameters_input.blur(
                        self.change_text2dict, input_extra_parameters_input,
                        input_extra_parameters)

                # Hidden until a model has been configured successfully.
                with gr.Column(visible=False) as output_col:
                    input_image = gr.Image(
                        image_mode='RGB',
                        tool='sketch',
                        type='filepath',
                        label='Input image')
                    infer_button = gr.Button('INFER')
                    infer_button.style(full_width=False)
                    output_image = gr.Image(
                        label='Output image', interactive=False)
                    output_image.style(height=500)
                    infer_button.click(
                        self.infer, inputs=input_image, outputs=output_image)

            with gr.Row():
                # Notification area fed by send_notification().
                label = gr.Label(
                    value=self.notice_message[0],
                    color=self.notice_message[1],
                    label=self.notice_message[2])

            def show_infer(*args) -> Dict:
                # Reconfigure the model, then reveal the inference column
                # and refresh the notification label.
                self.model_reconfig(*args)
                return {
                    output_col:
                    gr.update(visible=True),
                    label:
                    gr.update(
                        value=self.notice_message[0],
                        color=self.notice_message[1],
                        label=self.notice_message[2])
                }

            def change_setting(model_name: str) -> Dict:
                # Repopulate the settings dropdown for the chosen model.
                settings = InpaintingGradio.inpainting_supported_models_cfg[
                    model_name]['settings']
                return {
                    input_setting:
                    gr.update(choices=[
                        str(i) + ' - ' + settings[i]['Config']
                        for i in range(len(settings))
                    ])
                }

            input_model_dropdown.change(
                fn=change_setting,
                inputs=input_model_dropdown,
                outputs=input_setting)

            config_button.click(show_infer, [
                input_model_dropdown,
                input_setting,
                input_config,
                input_ckpt,
                input_device_dropdown,
                input_extra_parameters,
                input_seed,
            ], [output_col, label])
        demo.launch()


if __name__ == '__main__':
    # Launch the inpainting demo with no model preselected; the user
    # configures one from the GUI.
    gui = InpaintingGradio()
    gui.run()