From 5069e40eeaf00c8d2044eedf9e37445f02e9b855 Mon Sep 17 00:00:00 2001 From: Christopher Chou Date: Sat, 31 Aug 2024 05:55:55 +0000 Subject: [PATCH 1/5] Towards multi-image --- fastchat/conversation.py | 16 ++++--- fastchat/serve/gradio_block_arena_vision.py | 26 +++++++---- .../serve/gradio_block_arena_vision_anony.py | 17 +++++-- .../serve/gradio_block_arena_vision_named.py | 16 +++++-- fastchat/serve/moderation/moderator.py | 44 +++++++++++-------- 5 files changed, 80 insertions(+), 39 deletions(-) diff --git a/fastchat/conversation.py b/fastchat/conversation.py index ac4cc4cf3..21078938e 100644 --- a/fastchat/conversation.py +++ b/fastchat/conversation.py @@ -365,12 +365,16 @@ def to_gradio_chatbot(self): if i % 2 == 0: if type(msg) is tuple: msg, images = msg - image = images[0] # Only one image on gradio at one time - if image.image_format == ImageFormat.URL: - img_str = f'user upload image' - elif image.image_format == ImageFormat.BYTES: - img_str = f'user upload image' - msg = img_str + msg.replace("\n", "").strip() + combined_image_str = "" + for image in images: + if image.image_format == ImageFormat.URL: + img_str = ( + f'user upload image' + ) + elif image.image_format == ImageFormat.BYTES: + img_str = f'user upload image' + combined_image_str += img_str + msg = combined_image_str + msg.replace("\n", "").strip() ret.append([msg, None]) else: diff --git a/fastchat/serve/gradio_block_arena_vision.py b/fastchat/serve/gradio_block_arena_vision.py index 67f70d95b..87f4b4c6a 100644 --- a/fastchat/serve/gradio_block_arena_vision.py +++ b/fastchat/serve/gradio_block_arena_vision.py @@ -180,7 +180,7 @@ def _prepare_text_with_image(state, text, images): # reset convo with new image state.conv = get_conversation_template(state.model_name) - text = text, [images[0]] + text = text, images return text @@ -192,9 +192,10 @@ def convert_images_to_conversation_format(images): MAX_NSFW_ENDPOINT_IMAGE_SIZE_IN_MB = 5 / 1.5 conv_images = [] if len(images) > 0: - conv_image = Image(url=images[0]) - conv_image.to_conversation_format(MAX_NSFW_ENDPOINT_IMAGE_SIZE_IN_MB) - conv_images.append(conv_image) + for image in images: + conv_image = Image(url=image) + conv_image.to_conversation_format(MAX_NSFW_ENDPOINT_IMAGE_SIZE_IN_MB) + conv_images.append(conv_image) return conv_images @@ -233,17 +234,26 @@ def add_text(state, model_selector, chat_input, context: Context, request: gr.Re images = convert_images_to_conversation_format(images) # Use the first state to get the moderation response because this is based on user input so it is independent of the model - moderation_image_input = images[0] if len(images) > 0 else None moderation_type_to_response_map = ( state.content_moderator.image_and_text_moderation_filter( - moderation_image_input, text, [state.model_name], do_moderation=False + images, text, [state.model_name], do_moderation=False ) ) text_flagged, nsfw_flag, csam_flag = ( moderation_type_to_response_map["text_moderation"]["flagged"], - moderation_type_to_response_map["nsfw_moderation"]["flagged"], - moderation_type_to_response_map["csam_moderation"]["flagged"], + any( + [ + response["flagged"] + for response in moderation_type_to_response_map["nsfw_moderation"] + ] + ), + any( + [ + response["flagged"] + for response in moderation_type_to_response_map["csam_moderation"] + ] + ), ) if csam_flag: diff --git a/fastchat/serve/gradio_block_arena_vision_anony.py b/fastchat/serve/gradio_block_arena_vision_anony.py index 33c920c9b..e44fb3124 100644 --- a/fastchat/serve/gradio_block_arena_vision_anony.py 
+++ b/fastchat/serve/gradio_block_arena_vision_anony.py @@ -322,16 +322,25 @@ def add_text( images = convert_images_to_conversation_format(images) # Use the first state to get the moderation response because this is based on user input so it is independent of the model - moderation_image_input = images[0] if len(images) > 0 else None moderation_type_to_response_map = states[ 0 ].content_moderator.image_and_text_moderation_filter( - moderation_image_input, text, model_list, do_moderation=True + images, text, model_list, do_moderation=True ) text_flagged, nsfw_flag, csam_flag = ( moderation_type_to_response_map["text_moderation"]["flagged"], - moderation_type_to_response_map["nsfw_moderation"]["flagged"], - moderation_type_to_response_map["csam_moderation"]["flagged"], + any( + [ + response["flagged"] + for response in moderation_type_to_response_map["nsfw_moderation"] + ] + ), + any( + [ + response["flagged"] + for response in moderation_type_to_response_map["csam_moderation"] + ] + ), ) if csam_flag: diff --git a/fastchat/serve/gradio_block_arena_vision_named.py b/fastchat/serve/gradio_block_arena_vision_named.py index 279eb391d..57e74b537 100644 --- a/fastchat/serve/gradio_block_arena_vision_named.py +++ b/fastchat/serve/gradio_block_arena_vision_named.py @@ -253,7 +253,7 @@ def add_text( images = convert_images_to_conversation_format(images) # Use the first state to get the moderation response because this is based on user input so it is independent of the model - moderation_image_input = images[0] if len(images) > 0 else None + moderation_image_input = images if len(images) > 0 else None moderation_type_to_response_map = states[ 0 ].content_moderator.image_and_text_moderation_filter( @@ -262,8 +262,18 @@ def add_text( text_flagged, nsfw_flag, csam_flag = ( moderation_type_to_response_map["text_moderation"]["flagged"], - moderation_type_to_response_map["nsfw_moderation"]["flagged"], - moderation_type_to_response_map["csam_moderation"]["flagged"], + any( + [ + response["flagged"] + for response in moderation_type_to_response_map["nsfw_moderation"] + ] + ), + any( + [ + response["flagged"] + for response in moderation_type_to_response_map["csam_moderation"] + ] + ), ) if csam_flag: diff --git a/fastchat/serve/moderation/moderator.py b/fastchat/serve/moderation/moderator.py index 96bb30ae0..2a0f64eb7 100644 --- a/fastchat/serve/moderation/moderator.py +++ b/fastchat/serve/moderation/moderator.py @@ -35,7 +35,7 @@ def _text_moderation_filter(self, text: str) -> bool: raise NotImplementedError def image_and_text_moderation_filter( - self, image: Image, text: str + self, images: List[Image], text: str ) -> Dict[str, Dict[str, Union[str, Dict[str, float]]]]: """Function that detects whether image and text violate moderation policies. 
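Note: after this change, "text_moderation" remains a single response map while "nsfw_moderation" and "csam_moderation" become per-image lists, so callers reduce them with any(). A minimal self-contained sketch of that aggregation; the map shape follows the hunks above, but the concrete flag values are made up for illustration:

```python
from typing import Any, Dict

# Illustrative response in the new shape: one text map plus one entry
# per uploaded image for each image-moderation category.
moderation_type_to_response_map: Dict[str, Any] = {
    "text_moderation": {"flagged": False},
    "nsfw_moderation": [{"flagged": False}, {"flagged": True}],
    "csam_moderation": [{"flagged": False}, {"flagged": False}],
}

text_flagged = moderation_type_to_response_map["text_moderation"]["flagged"]
# One flagged image flags the whole request; any([]) is False, so
# text-only requests (empty image lists) are unaffected.
nsfw_flag = any(
    response["flagged"]
    for response in moderation_type_to_response_map["nsfw_moderation"]
)
csam_flag = any(
    response["flagged"]
    for response in moderation_type_to_response_map["csam_moderation"]
)

assert (text_flagged, nsfw_flag, csam_flag) == (False, True, False)
```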
@@ -121,25 +121,33 @@ def _image_moderation_provider(self, image_bytes: bytes, api_type: str) -> bool: return moderation_response_map - def image_moderation_filter(self, image: Image): - print(f"moderating image") + def image_moderation_filter(self, images: List[Image]): + print(f"moderating images") - image_bytes = base64.b64decode(image.base64_str) + images_moderation_response: Dict[ + str, List[Dict[str, Union[str, Dict[str, float]]]] + ] = { + "nsfw_moderation": [], + "csam_moderation": [], + } - nsfw_flagged_map = self._image_moderation_provider(image_bytes, "nsfw") + for image in images: + image_bytes = base64.b64decode(image.base64_str) - if nsfw_flagged_map["flagged"]: - csam_flagged_map = self._image_moderation_provider(image_bytes, "csam") - else: - csam_flagged_map = {"flagged": False} + nsfw_flagged_map = self._image_moderation_provider(image_bytes, "nsfw") - self.nsfw_flagged = nsfw_flagged_map["flagged"] - self.csam_flagged = csam_flagged_map["flagged"] + if nsfw_flagged_map["flagged"]: + csam_flagged_map = self._image_moderation_provider(image_bytes, "csam") + else: + csam_flagged_map = {"flagged": False} - return { - "nsfw_moderation": nsfw_flagged_map, - "csam_moderation": csam_flagged_map, - } + self.nsfw_flagged |= nsfw_flagged_map["flagged"] + self.csam_flagged |= csam_flagged_map["flagged"] + + images_moderation_response["nsfw_moderation"].append(nsfw_flagged_map) + images_moderation_response["csam_moderation"].append(csam_flagged_map) + + return images_moderation_response def _openai_moderation_filter( self, text: str, custom_thresholds: dict = None @@ -212,7 +220,7 @@ def text_moderation_filter( return {"text_moderation": moderation_response_map} def image_and_text_moderation_filter( - self, image: Image, text: str, model_list: List[str], do_moderation=True + self, images: List[Image], text: str, model_list: List[str], do_moderation=True ) -> Dict[str, bool]: """Function that detects whether image and text violate moderation policies using the Azure and OpenAI moderation APIs. 
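Note: the rewritten image_moderation_filter only calls the CSAM endpoint for images the NSFW check already flagged, and folds per-image results into the instance-level flags with |=. A runnable sketch of that control flow, with a stub standing in for _image_moderation_provider (the stub and its flagging rule are assumptions for demonstration only):

```python
from typing import Dict, List

def stub_provider(image_bytes: bytes, api_type: str) -> Dict[str, bool]:
    # Stand-in for _image_moderation_provider: "flags" an image whenever
    # its bytes contain the api_type tag. Purely for demonstration.
    return {"flagged": api_type.encode() in image_bytes}

def moderate_images(images: List[bytes]) -> Dict[str, List[Dict[str, bool]]]:
    response: Dict[str, List[Dict[str, bool]]] = {
        "nsfw_moderation": [],
        "csam_moderation": [],
    }
    nsfw_flagged = csam_flagged = False
    for image_bytes in images:
        nsfw_map = stub_provider(image_bytes, "nsfw")
        # Only escalate to the separate CSAM endpoint when NSFW flags.
        if nsfw_map["flagged"]:
            csam_map = stub_provider(image_bytes, "csam")
        else:
            csam_map = {"flagged": False}
        # One bad image taints the whole batch.
        nsfw_flagged |= nsfw_map["flagged"]
        csam_flagged |= csam_map["flagged"]
        response["nsfw_moderation"].append(nsfw_map)
        response["csam_moderation"].append(csam_map)
    return response

print(moderate_images([b"clean image", b"nsfw image"]))
# {'nsfw_moderation': [{'flagged': False}, {'flagged': True}],
#  'csam_moderation': [{'flagged': False}, {'flagged': False}]}
```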
@@ -247,8 +255,8 @@ def image_and_text_moderation_filter( print("moderating text: ", text) text_flagged_map = self.text_moderation_filter(text, model_list, do_moderation) - if image is not None: - image_flagged_map = self.image_moderation_filter(image) + if images: + image_flagged_map = self.image_moderation_filter(images) else: image_flagged_map = self._NON_TOXIC_IMAGE_MODERATION_MAP From bdb074d1489a553e3ac4c88041b3a91152c7bdda Mon Sep 17 00:00:00 2001 From: Christopher Chou Date: Sat, 31 Aug 2024 06:53:11 +0000 Subject: [PATCH 2/5] Basic prototype working --- fastchat/serve/gradio_block_arena_vision.py | 4 - .../serve/gradio_block_arena_vision_anony.py | 145 ++++++++---------- fastchat/serve/moderation/moderator.py | 4 +- 3 files changed, 67 insertions(+), 86 deletions(-) diff --git a/fastchat/serve/gradio_block_arena_vision.py b/fastchat/serve/gradio_block_arena_vision.py index 87f4b4c6a..3ca764da7 100644 --- a/fastchat/serve/gradio_block_arena_vision.py +++ b/fastchat/serve/gradio_block_arena_vision.py @@ -176,10 +176,6 @@ def report_csam_image(state, image): def _prepare_text_with_image(state, text, images): if len(images) > 0: - if len(state.conv.get_images()) > 0: - # reset convo with new image - state.conv = get_conversation_template(state.model_name) - text = text, images return text diff --git a/fastchat/serve/gradio_block_arena_vision_anony.py b/fastchat/serve/gradio_block_arena_vision_anony.py index e44fb3124..9f8fa7f38 100644 --- a/fastchat/serve/gradio_block_arena_vision_anony.py +++ b/fastchat/serve/gradio_block_arena_vision_anony.py @@ -138,7 +138,7 @@ def clear_history_example(request: gr.Request): [None] * num_sides + [None] * num_sides + anony_names - + [enable_multimodal_keep_input, invisible_text, invisible_btn] + + [enable_multimodal_keep_input] + [invisible_btn] * 4 + [disable_btn] * 2 + [enable_btn] @@ -244,7 +244,7 @@ def clear_history(request: gr.Request): [None] * num_sides + [None] * num_sides + anony_names - + [enable_multimodal_clear_input, invisible_text, invisible_btn] + + [enable_multimodal_clear_input] + [invisible_btn] * 4 + [disable_btn] * 2 + [enable_btn] @@ -308,7 +308,7 @@ def add_text( return ( states + [x.to_gradio_chatbot() for x in states] - + [None, "", no_change_btn] + + [None] + [ no_change_btn, ] @@ -359,7 +359,7 @@ def add_text( return ( states + [x.to_gradio_chatbot() for x in states] - + [{"text": CONVERSATION_LIMIT_MSG}, "", no_change_btn] + + [{"text": CONVERSATION_LIMIT_MSG}] + [ no_change_btn, ] @@ -382,8 +382,6 @@ def add_text( + gradio_chatbot_list + [ None, - "", - no_change_btn, ] + [disable_btn] * 7 + [""] @@ -404,7 +402,7 @@ def add_text( return ( states + [x.to_gradio_chatbot() for x in states] - + [disable_multimodal, visible_text, enable_btn] + + [None] + [ disable_btn, ] @@ -496,12 +494,12 @@ def build_side_by_side_vision_ui_anony(context: Context, random_questions=None): ) with gr.Row(): - textbox = gr.Textbox( - show_label=False, - placeholder="👉 Enter your prompt and press ENTER", - elem_id="input_box", - visible=False, - ) + # textbox = gr.Textbox( + # show_label=False, + # placeholder="👉 Enter your prompt and press ENTER", + # elem_id="input_box", + # visible=False, + # ) multimodal_textbox = gr.MultimodalTextbox( file_types=["image"], @@ -510,9 +508,9 @@ def build_side_by_side_vision_ui_anony(context: Context, random_questions=None): placeholder="Enter your prompt or add image here", elem_id="input_box", ) - send_btn = gr.Button( - value="Send", variant="primary", scale=0, visible=False, interactive=False - ) + # 
send_btn = gr.Button( + # value="Send", variant="primary", scale=0, visible=False, interactive=False + # ) with gr.Row() as button_row: if random_questions: @@ -564,25 +562,29 @@ def build_side_by_side_vision_ui_anony(context: Context, random_questions=None): leftvote_btn.click( leftvote_last_response, states + model_selectors, - model_selectors + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], + model_selectors + + [multimodal_textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], ) rightvote_btn.click( rightvote_last_response, states + model_selectors, - model_selectors + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], + model_selectors + + [multimodal_textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], ) tie_btn.click( tievote_last_response, states + model_selectors, - model_selectors + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], + model_selectors + + [multimodal_textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], ) bothbad_btn.click( bothbad_vote_last_response, states + model_selectors, - model_selectors + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], + model_selectors + + [multimodal_textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], ) regenerate_btn.click( - regenerate, states, states + chatbots + [textbox] + btn_list + regenerate, states, states + chatbots + [multimodal_textbox] + btn_list ).then( bot_response_multi, states + [temperature, top_p, max_output_tokens], @@ -596,7 +598,7 @@ def build_side_by_side_vision_ui_anony(context: Context, random_questions=None): states + chatbots + model_selectors - + [multimodal_textbox, textbox, send_btn] + + [multimodal_textbox] + btn_list + [random_btn] + [slow_warning], @@ -626,14 +628,6 @@ def build_side_by_side_vision_ui_anony(context: Context, random_questions=None): multimodal_textbox.input(add_image, [multimodal_textbox], [imagebox]).then( set_visible_image, [multimodal_textbox], [image_column] - ).then( - clear_history_example, - None, - states - + chatbots - + model_selectors - + [multimodal_textbox, textbox, send_btn] - + btn_list, ) multimodal_textbox.submit( @@ -641,7 +635,7 @@ def build_side_by_side_vision_ui_anony(context: Context, random_questions=None): states + model_selectors + [multimodal_textbox, context_state], states + chatbots - + [multimodal_textbox, textbox, send_btn] + + [multimodal_textbox] + btn_list + [random_btn] + [slow_warning] @@ -656,60 +650,51 @@ def build_side_by_side_vision_ui_anony(context: Context, random_questions=None): btn_list, ) - textbox.submit( - add_text, - states + model_selectors + [textbox, context_state], - states - + chatbots - + [multimodal_textbox, textbox, send_btn] - + btn_list - + [random_btn] - + [slow_warning] - + [show_vote_buttons], - ).then( - bot_response_multi, - states + [temperature, top_p, max_output_tokens], - states + chatbots + btn_list, - ).then( - flash_buttons, - [show_vote_buttons], - btn_list, - ) - - send_btn.click( - add_text, - states + model_selectors + [textbox, context_state], - states - + chatbots - + [multimodal_textbox, textbox, send_btn] - + btn_list - + [random_btn] - + [slow_warning] - + [show_vote_buttons], - ).then( - bot_response_multi, - states + [temperature, top_p, max_output_tokens], - states + chatbots + btn_list, - ).then( - flash_buttons, - [show_vote_buttons], - btn_list, - ) + # textbox.submit( + # add_text, + # states + model_selectors + [textbox, context_state], + # states + # + chatbots + # + [multimodal_textbox] + # + btn_list + # + [random_btn] + 
# + [slow_warning] + # + [show_vote_buttons], + # ).then( + # bot_response_multi, + # states + [temperature, top_p, max_output_tokens], + # states + chatbots + btn_list, + # ).then( + # flash_buttons, + # [show_vote_buttons], + # btn_list, + # ) + + # send_btn.click( + # add_text, + # states + model_selectors + [textbox, context_state], + # states + # + chatbots + # + [multimodal_textbox] + # + btn_list + # + [random_btn] + # + [slow_warning] + # + [show_vote_buttons], + # ).then( + # bot_response_multi, + # states + [temperature, top_p, max_output_tokens], + # states + chatbots + btn_list, + # ).then( + # flash_buttons, + # [show_vote_buttons], + # btn_list, + # ) if random_questions: random_btn.click( get_vqa_sample, # First, get the VQA sample [], # Pass the path to the VQA samples [multimodal_textbox, imagebox], # Outputs are textbox and imagebox - ).then(set_visible_image, [multimodal_textbox], [image_column]).then( - clear_history_example, - None, - states - + chatbots - + model_selectors - + [multimodal_textbox, textbox, send_btn] - + btn_list - + [random_btn], - ) + ).then(set_visible_image, [multimodal_textbox], [image_column]) return states + model_selectors diff --git a/fastchat/serve/moderation/moderator.py b/fastchat/serve/moderation/moderator.py index 2a0f64eb7..1005d564e 100644 --- a/fastchat/serve/moderation/moderator.py +++ b/fastchat/serve/moderation/moderator.py @@ -65,8 +65,8 @@ def update_last_moderation_response( class AzureAndOpenAIContentModerator(BaseContentModerator): _NON_TOXIC_IMAGE_MODERATION_MAP = { - "nsfw_moderation": {"flagged": False}, - "csam_moderation": {"flagged": False}, + "nsfw_moderation": [{"flagged": False}], + "csam_moderation": [{"flagged": False}], } def __init__(self, use_remote_storage: bool = False): From 3666d3e2707dc06f32dba1ad8268a38436776f8e Mon Sep 17 00:00:00 2001 From: Christopher Chou Date: Sat, 31 Aug 2024 07:17:33 +0000 Subject: [PATCH 3/5] Filter with api endpoint info --- fastchat/serve/gradio_block_arena_vision.py | 13 ++++++++++++- fastchat/serve/gradio_block_arena_vision_anony.py | 6 ++++-- fastchat/serve/gradio_global_state.py | 1 + fastchat/serve/gradio_web_server.py | 5 +++++ fastchat/serve/gradio_web_server_multi.py | 12 ++++++++++-- 5 files changed, 32 insertions(+), 5 deletions(-) diff --git a/fastchat/serve/gradio_block_arena_vision.py b/fastchat/serve/gradio_block_arena_vision.py index 3ca764da7..afb1dd66c 100644 --- a/fastchat/serve/gradio_block_arena_vision.py +++ b/fastchat/serve/gradio_block_arena_vision.py @@ -174,8 +174,19 @@ def report_csam_image(state, image): pass -def _prepare_text_with_image(state, text, images): +def _prepare_text_with_image( + state: State, text: str, images: List[Image], context: Context +): if len(images) > 0: + model_supports_multi_image = context.api_endpoint_info[state.model_name].get( + "multi_image", False + ) + if len(state.conv.get_images()) > 0 and not model_supports_multi_image: + gr.Warning( + f"The model does not support multiple images. Only the first image will be used." 
+ ) + return text + text = text, images return text diff --git a/fastchat/serve/gradio_block_arena_vision_anony.py b/fastchat/serve/gradio_block_arena_vision_anony.py index 9f8fa7f38..0ee87cddd 100644 --- a/fastchat/serve/gradio_block_arena_vision_anony.py +++ b/fastchat/serve/gradio_block_arena_vision_anony.py @@ -373,7 +373,9 @@ def add_text( # We call this before appending the text so it does not appear in the UI gradio_chatbot_list = [x.to_gradio_chatbot() for x in states] for i in range(num_sides): - post_processed_text = _prepare_text_with_image(states[i], text, images) + post_processed_text = _prepare_text_with_image( + states[i], text, images, context + ) states[i].conv.append_message(states[i].conv.roles[0], post_processed_text) states[i].skip_next = True gr.Warning(MODERATION_MSG) @@ -390,7 +392,7 @@ def add_text( text = text[:BLIND_MODE_INPUT_CHAR_LEN_LIMIT] # Hard cut-off for i in range(num_sides): - post_processed_text = _prepare_text_with_image(states[i], text, images) + post_processed_text = _prepare_text_with_image(states[i], text, images, context) states[i].conv.append_message(states[i].conv.roles[0], post_processed_text) states[i].conv.append_message(states[i].conv.roles[1], None) states[i].skip_next = False diff --git a/fastchat/serve/gradio_global_state.py b/fastchat/serve/gradio_global_state.py index e05022f18..de911985d 100644 --- a/fastchat/serve/gradio_global_state.py +++ b/fastchat/serve/gradio_global_state.py @@ -8,3 +8,4 @@ class Context: all_text_models: List[str] = field(default_factory=list) vision_models: List[str] = field(default_factory=list) all_vision_models: List[str] = field(default_factory=list) + api_endpoint_info: dict = field(default_factory=dict) diff --git a/fastchat/serve/gradio_web_server.py b/fastchat/serve/gradio_web_server.py index c4cef2e19..46ff0c421 100644 --- a/fastchat/serve/gradio_web_server.py +++ b/fastchat/serve/gradio_web_server.py @@ -228,6 +228,11 @@ def get_model_list(controller_url, register_api_endpoint_file, vision_arena): return visible_models, models +def _get_api_endpoint_info(): + global api_endpoint_info + return api_endpoint_info + + def load_demo_single(context: Context, query_params): # default to text models models = context.text_models diff --git a/fastchat/serve/gradio_web_server_multi.py b/fastchat/serve/gradio_web_server_multi.py index ade8abc6b..f19fab6ec 100644 --- a/fastchat/serve/gradio_web_server_multi.py +++ b/fastchat/serve/gradio_web_server_multi.py @@ -41,6 +41,7 @@ get_model_list, load_demo_single, get_ip, + _get_api_endpoint_info, ) from fastchat.serve.monitor.monitor import build_leaderboard_tab from fastchat.utils import ( @@ -318,8 +319,15 @@ def build_demo(context: Context, elo_results_file: str, leaderboard_table_file): args.register_api_endpoint_file, vision_arena=True, ) - - context = Context(text_models, all_text_models, vision_models, all_vision_models) + api_endpoint_info = _get_api_endpoint_info() + + context = Context( + text_models, + all_text_models, + vision_models, + all_vision_models, + api_endpoint_info, + ) # Set authorization credentials auth = None From 2bf97d4b3ad306dbd1022bfb907a7c378dce63d4 Mon Sep 17 00:00:00 2001 From: Christopher Chou Date: Sat, 31 Aug 2024 07:29:06 +0000 Subject: [PATCH 4/5] Add for non-anony and direct --- fastchat/serve/gradio_block_arena_vision.py | 87 +++++-------------- .../serve/gradio_block_arena_vision_anony.py | 50 ----------- .../serve/gradio_block_arena_vision_named.py | 78 ++++------------- fastchat/serve/moderation/moderator.py | 1 + 4 files 
changed, 41 insertions(+), 175 deletions(-) diff --git a/fastchat/serve/gradio_block_arena_vision.py b/fastchat/serve/gradio_block_arena_vision.py index afb1dd66c..692915c2e 100644 --- a/fastchat/serve/gradio_block_arena_vision.py +++ b/fastchat/serve/gradio_block_arena_vision.py @@ -155,18 +155,14 @@ def clear_history(request: gr.Request): ip = get_ip(request) logger.info(f"clear_history. ip: {ip}") state = None - return (state, [], enable_multimodal_clear_input, invisible_text, invisible_btn) + ( - disable_btn, - ) * 5 + return (state, [], enable_multimodal_clear_input) + (disable_btn,) * 5 def clear_history_example(request: gr.Request): ip = get_ip(request) logger.info(f"clear_history_example. ip: {ip}") state = None - return (state, [], enable_multimodal_keep_input, invisible_text, invisible_btn) + ( - disable_btn, - ) * 5 + return (state, [], enable_multimodal_keep_input) + (disable_btn,) * 5 # TODO(Chris): At some point, we would like this to be a live-reporting feature. @@ -231,9 +227,7 @@ def add_text(state, model_selector, chat_input, context: Context, request: gr.Re if len(text) <= 0: state.skip_next = True - return (state, state.to_gradio_chatbot(), None, "", no_change_btn) + ( - no_change_btn, - ) * 5 + return (state, state.to_gradio_chatbot(), None) + (no_change_btn,) * 5 all_conv_text = state.conv.get_prompt() all_conv_text = all_conv_text[-2000:] + "\nuser: " + text @@ -271,13 +265,15 @@ def add_text(state, model_selector, chat_input, context: Context, request: gr.Re if text_flagged or nsfw_flag: logger.info(f"violate moderation. ip: {ip}. text: {text}") gradio_chatbot_before_user_input = state.to_gradio_chatbot() - post_processed_text = _prepare_text_with_image(state, text, images) + post_processed_text = _prepare_text_with_image(state, text, images, context) state.conv.append_message(state.conv.roles[0], post_processed_text) state.skip_next = True gr.Warning(MODERATION_MSG) - return (state, gradio_chatbot_before_user_input, None, "", no_change_btn) + ( - no_change_btn, - ) * 5 + return ( + state, + gradio_chatbot_before_user_input, + None, + ) + (no_change_btn,) * 5 if (len(state.conv.messages) - state.conv.offset) // 2 >= CONVERSATION_TURN_LIMIT: logger.info(f"conversation turn limit. ip: {ip}. 
text: {text}") @@ -286,20 +282,16 @@ def add_text(state, model_selector, chat_input, context: Context, request: gr.Re state, state.to_gradio_chatbot(), {"text": CONVERSATION_LIMIT_MSG}, - "", - no_change_btn, ) + (no_change_btn,) * 5 text = text[:INPUT_CHAR_LEN_LIMIT] # Hard cut-off - text = _prepare_text_with_image(state, text, images) + text = _prepare_text_with_image(state, text, images, context) state.conv.append_message(state.conv.roles[0], text) state.conv.append_message(state.conv.roles[1], None) return ( state, state.to_gradio_chatbot(), - disable_multimodal, - visible_text, - enable_btn, + None, ) + (disable_btn,) * 5 @@ -361,17 +353,6 @@ def build_single_vision_language_model_ui( ) with gr.Row(): - textbox = gr.Textbox( - show_label=False, - placeholder="👉 Enter your prompt and press ENTER", - elem_id="input_box", - visible=False, - ) - - send_btn = gr.Button( - value="Send", variant="primary", scale=0, visible=False, interactive=False - ) - multimodal_textbox = gr.MultimodalTextbox( file_types=["image"], show_label=False, @@ -426,19 +407,21 @@ def build_single_vision_language_model_ui( upvote_btn.click( upvote_last_response, [state, model_selector], - [textbox, upvote_btn, downvote_btn, flag_btn], + [multimodal_textbox, upvote_btn, downvote_btn, flag_btn], ) downvote_btn.click( downvote_last_response, [state, model_selector], - [textbox, upvote_btn, downvote_btn, flag_btn], + [multimodal_textbox, upvote_btn, downvote_btn, flag_btn], ) flag_btn.click( flag_last_response, [state, model_selector], - [textbox, upvote_btn, downvote_btn, flag_btn], + [multimodal_textbox, upvote_btn, downvote_btn, flag_btn], ) - regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then( + regenerate_btn.click( + regenerate, state, [state, chatbot, multimodal_textbox] + btn_list + ).then( bot_response, [state, temperature, top_p, max_output_tokens], [state, chatbot] + btn_list, @@ -446,47 +429,23 @@ def build_single_vision_language_model_ui( clear_btn.click( clear_history, None, - [state, chatbot, multimodal_textbox, textbox, send_btn] + btn_list, + [state, chatbot, multimodal_textbox] + btn_list, ) model_selector.change( clear_history, None, - [state, chatbot, multimodal_textbox, textbox, send_btn] + btn_list, + [state, chatbot, multimodal_textbox] + btn_list, ).then(set_visible_image, [multimodal_textbox], [image_column]) multimodal_textbox.input(add_image, [multimodal_textbox], [imagebox]).then( set_visible_image, [multimodal_textbox], [image_column] - ).then( - clear_history_example, - None, - [state, chatbot, multimodal_textbox, textbox, send_btn] + btn_list, ) multimodal_textbox.submit( add_text, [state, model_selector, multimodal_textbox, context_state], - [state, chatbot, multimodal_textbox, textbox, send_btn] + btn_list, - ).then(set_invisible_image, [], [image_column]).then( - bot_response, - [state, temperature, top_p, max_output_tokens], - [state, chatbot] + btn_list, - ) - - textbox.submit( - add_text, - [state, model_selector, textbox, context_state], - [state, chatbot, multimodal_textbox, textbox, send_btn] + btn_list, - ).then(set_invisible_image, [], [image_column]).then( - bot_response, - [state, temperature, top_p, max_output_tokens], - [state, chatbot] + btn_list, - ) - - send_btn.click( - add_text, - [state, model_selector, textbox, context_state], - [state, chatbot, multimodal_textbox, textbox, send_btn] + btn_list, + [state, chatbot, multimodal_textbox] + btn_list, ).then(set_invisible_image, [], [image_column]).then( bot_response, [state, temperature, 
top_p, max_output_tokens], @@ -498,10 +457,6 @@ def build_single_vision_language_model_ui( get_vqa_sample, # First, get the VQA sample [], # Pass the path to the VQA samples [multimodal_textbox, imagebox], # Outputs are textbox and imagebox - ).then(set_visible_image, [multimodal_textbox], [image_column]).then( - clear_history_example, - None, - [state, chatbot, multimodal_textbox, textbox, send_btn] + btn_list, - ) + ).then(set_visible_image, [multimodal_textbox], [image_column]) return [state, model_selector] diff --git a/fastchat/serve/gradio_block_arena_vision_anony.py b/fastchat/serve/gradio_block_arena_vision_anony.py index 0ee87cddd..d24a54723 100644 --- a/fastchat/serve/gradio_block_arena_vision_anony.py +++ b/fastchat/serve/gradio_block_arena_vision_anony.py @@ -496,13 +496,6 @@ def build_side_by_side_vision_ui_anony(context: Context, random_questions=None): ) with gr.Row(): - # textbox = gr.Textbox( - # show_label=False, - # placeholder="👉 Enter your prompt and press ENTER", - # elem_id="input_box", - # visible=False, - # ) - multimodal_textbox = gr.MultimodalTextbox( file_types=["image"], show_label=False, @@ -510,9 +503,6 @@ def build_side_by_side_vision_ui_anony(context: Context, random_questions=None): placeholder="Enter your prompt or add image here", elem_id="input_box", ) - # send_btn = gr.Button( - # value="Send", variant="primary", scale=0, visible=False, interactive=False - # ) with gr.Row() as button_row: if random_questions: @@ -652,46 +642,6 @@ def build_side_by_side_vision_ui_anony(context: Context, random_questions=None): btn_list, ) - # textbox.submit( - # add_text, - # states + model_selectors + [textbox, context_state], - # states - # + chatbots - # + [multimodal_textbox] - # + btn_list - # + [random_btn] - # + [slow_warning] - # + [show_vote_buttons], - # ).then( - # bot_response_multi, - # states + [temperature, top_p, max_output_tokens], - # states + chatbots + btn_list, - # ).then( - # flash_buttons, - # [show_vote_buttons], - # btn_list, - # ) - - # send_btn.click( - # add_text, - # states + model_selectors + [textbox, context_state], - # states - # + chatbots - # + [multimodal_textbox] - # + btn_list - # + [random_btn] - # + [slow_warning] - # + [show_vote_buttons], - # ).then( - # bot_response_multi, - # states + [temperature, top_p, max_output_tokens], - # states + chatbots + btn_list, - # ).then( - # flash_buttons, - # [show_vote_buttons], - # btn_list, - # ) - if random_questions: random_btn.click( get_vqa_sample, # First, get the VQA sample diff --git a/fastchat/serve/gradio_block_arena_vision_named.py b/fastchat/serve/gradio_block_arena_vision_named.py index 57e74b537..218eb344e 100644 --- a/fastchat/serve/gradio_block_arena_vision_named.py +++ b/fastchat/serve/gradio_block_arena_vision_named.py @@ -99,7 +99,7 @@ def clear_history_example(request: gr.Request): return ( [None] * num_sides + [None] * num_sides - + [enable_multimodal_keep_input, invisible_text, invisible_btn] + + [enable_multimodal_keep_input] + [invisible_btn] * 4 + [disable_btn] * 2 ) @@ -184,7 +184,7 @@ def clear_history(request: gr.Request): return ( [None] * num_sides + [None] * num_sides - + [enable_multimodal_clear_input, invisible_text, invisible_btn] + + [enable_multimodal_clear_input] + [invisible_btn] * 4 + [disable_btn] * 2 ) @@ -236,7 +236,7 @@ def add_text( return ( states + [x.to_gradio_chatbot() for x in states] - + [None, "", no_change_btn] + + [None] + [ no_change_btn, ] @@ -295,7 +295,7 @@ def add_text( return ( states + [x.to_gradio_chatbot() for x in states] - + 
[{"text": CONVERSATION_LIMIT_MSG}, "", no_change_btn] + + [{"text": CONVERSATION_LIMIT_MSG}] + [ no_change_btn, ] @@ -306,14 +306,16 @@ def add_text( logger.info(f"violate moderation. ip: {ip}. text: {text}") gradio_chatbot_list = [x.to_gradio_chatbot() for x in states] for i in range(num_sides): - post_processed_text = _prepare_text_with_image(states[i], text, images) + post_processed_text = _prepare_text_with_image( + states[i], text, images, context + ) states[i].conv.append_message(states[i].conv.roles[0], post_processed_text) states[i].skip_next = True gr.Warning(MODERATION_MSG) return ( states + gradio_chatbot_list - + [None, "", no_change_btn] + + [None] + [ no_change_btn, ] @@ -326,6 +328,7 @@ def add_text( states[i], text, images, + context, ) states[i].conv.append_message(states[i].conv.roles[0], post_processed_text) states[i].conv.append_message(states[i].conv.roles[1], None) @@ -334,7 +337,7 @@ def add_text( return ( states + [x.to_gradio_chatbot() for x in states] - + [disable_multimodal, visible_text, enable_btn] + + [None] + [ disable_btn, ] @@ -427,17 +430,6 @@ def build_side_by_side_vision_ui_named(context: Context, random_questions=None): ) with gr.Row(): - textbox = gr.Textbox( - show_label=False, - placeholder="👉 Enter your prompt and press ENTER", - elem_id="input_box", - visible=False, - ) - - send_btn = gr.Button( - value="Send", variant="primary", scale=0, visible=False, interactive=False - ) - multimodal_textbox = gr.MultimodalTextbox( file_types=["image"], show_label=False, @@ -496,25 +488,25 @@ def build_side_by_side_vision_ui_named(context: Context, random_questions=None): leftvote_btn.click( leftvote_last_response, states + model_selectors, - [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], + [multimodal_textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], ) rightvote_btn.click( rightvote_last_response, states + model_selectors, - [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], + [multimodal_textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], ) tie_btn.click( tievote_last_response, states + model_selectors, - [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], + [multimodal_textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], ) bothbad_btn.click( bothbad_vote_last_response, states + model_selectors, - [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], + [multimodal_textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], ) regenerate_btn.click( - regenerate, states, states + chatbots + [textbox] + btn_list + regenerate, states, states + chatbots + [multimodal_textbox] + btn_list ).then( bot_response_multi, states + [temperature, top_p, max_output_tokens], @@ -525,7 +517,7 @@ def build_side_by_side_vision_ui_named(context: Context, random_questions=None): clear_btn.click( clear_history, None, - states + chatbots + [multimodal_textbox, textbox, send_btn] + btn_list, + states + chatbots + [multimodal_textbox] + btn_list, ) share_js = """ @@ -554,45 +546,17 @@ def build_side_by_side_vision_ui_named(context: Context, random_questions=None): model_selectors[i].change( clear_history, None, - states + chatbots + [multimodal_textbox, textbox, send_btn] + btn_list, + states + chatbots + [multimodal_textbox] + btn_list, ).then(set_visible_image, [multimodal_textbox], [image_column]) multimodal_textbox.input(add_image, [multimodal_textbox], [imagebox]).then( set_visible_image, [multimodal_textbox], [image_column] - ).then( - clear_history_example, - None, - states + chatbots + [multimodal_textbox, 
textbox, send_btn] + btn_list, ) multimodal_textbox.submit( add_text, states + model_selectors + [multimodal_textbox, context_state], - states + chatbots + [multimodal_textbox, textbox, send_btn] + btn_list, - ).then(set_invisible_image, [], [image_column]).then( - bot_response_multi, - states + [temperature, top_p, max_output_tokens], - states + chatbots + btn_list, - ).then( - flash_buttons, [], btn_list - ) - - textbox.submit( - add_text, - states + model_selectors + [textbox, context_state], - states + chatbots + [multimodal_textbox, textbox, send_btn] + btn_list, - ).then(set_invisible_image, [], [image_column]).then( - bot_response_multi, - states + [temperature, top_p, max_output_tokens], - states + chatbots + btn_list, - ).then( - flash_buttons, [], btn_list - ) - - send_btn.click( - add_text, - states + model_selectors + [textbox, context_state], - states + chatbots + [multimodal_textbox, textbox, send_btn] + btn_list, + states + chatbots + [multimodal_textbox] + btn_list, ).then(set_invisible_image, [], [image_column]).then( bot_response_multi, states + [temperature, top_p, max_output_tokens], @@ -606,10 +570,6 @@ def build_side_by_side_vision_ui_named(context: Context, random_questions=None): get_vqa_sample, # First, get the VQA sample [], # Pass the path to the VQA samples [multimodal_textbox, imagebox], # Outputs are textbox and imagebox - ).then(set_visible_image, [multimodal_textbox], [image_column]).then( - clear_history_example, - None, - states + chatbots + [multimodal_textbox, textbox, send_btn] + btn_list, - ) + ).then(set_visible_image, [multimodal_textbox], [image_column]) return states + model_selectors diff --git a/fastchat/serve/moderation/moderator.py b/fastchat/serve/moderation/moderator.py index 1005d564e..efafdcd88 100644 --- a/fastchat/serve/moderation/moderator.py +++ b/fastchat/serve/moderation/moderator.py @@ -211,6 +211,7 @@ def text_moderation_filter( do_moderation = True break + moderation_response_map = {"flagged": False} if do_moderation: moderation_response_map = self._openai_moderation_filter( text, custom_thresholds From 2d96a40653c4ae3e88bc4414e6d5b88c32090bd1 Mon Sep 17 00:00:00 2001 From: Christopher Chou Date: Sat, 31 Aug 2024 17:16:01 +0000 Subject: [PATCH 5/5] Fix --- fastchat/serve/gradio_block_arena_vision.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/fastchat/serve/gradio_block_arena_vision.py b/fastchat/serve/gradio_block_arena_vision.py index 692915c2e..dadb360e0 100644 --- a/fastchat/serve/gradio_block_arena_vision.py +++ b/fastchat/serve/gradio_block_arena_vision.py @@ -85,10 +85,6 @@ def set_visible_image(textbox): images = textbox["files"] if len(images) == 0: return invisible_image_column - elif len(images) > 1: - gr.Warning( - "We only support single image conversations. Please start a new round if you would like to chat using this image." 
-        )
     return visible_image_column
@@ -177,11 +173,23 @@ def _prepare_text_with_image(
         model_supports_multi_image = context.api_endpoint_info[state.model_name].get(
             "multi_image", False
         )
-        if len(state.conv.get_images()) > 0 and not model_supports_multi_image:
+        num_previous_images = len(state.conv.get_images())
+        images_interleaved_with_text_exists_but_model_does_not_support = (
+            num_previous_images > 0 and not model_supports_multi_image
+        )
+        multiple_image_one_turn_but_model_does_not_support = (
+            len(images) > 1 and not model_supports_multi_image
+        )
+        if images_interleaved_with_text_exists_but_model_does_not_support:
             gr.Warning(
-                f"The model does not support multiple images. Only the first image will be used."
+                "This model does not support multiple images in a conversation; only the first image will be used."
             )
             return text
+        elif multiple_image_one_turn_but_model_does_not_support:
+            gr.Warning(
+                "This model does not support multiple images per message; only the first image will be used."
+            )
+            return text, [images[0]]
 
         text = text, images
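Note on patch 3: it threads api_endpoint_info through Context so _prepare_text_with_image can check a per-model "multi_image" capability. A minimal sketch of that lookup follows; the model names and dict entries are hypothetical, and only the "multi_image" key is what the patched code actually reads:

```python
# Hypothetical endpoint registrations; only the "multi_image" key matters
# to _prepare_text_with_image -- everything else here is illustrative.
api_endpoint_info = {
    "multi-image-model": {"multi_image": True},
    "single-image-model": {},  # no key: .get("multi_image", False) -> False
}

def supports_multi_image(model_name: str) -> bool:
    return api_endpoint_info.get(model_name, {}).get("multi_image", False)

assert supports_multi_image("multi-image-model") is True
assert supports_multi_image("single-image-model") is False
```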
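With patch 5 applied, _prepare_text_with_image has three possible outcomes for a model lacking multi-image support. A standalone mimic of that control flow, with images simplified to plain strings and the gr.Warning calls omitted:

```python
from typing import List, Tuple, Union

def prepare_text_with_image_sketch(
    text: str, images: List[str], num_previous_images: int, multi_image: bool
) -> Union[str, Tuple[str, List[str]]]:
    # Mirrors the branching added in patch 5.
    if not images:
        return text
    if num_previous_images > 0 and not multi_image:
        return text  # new images are dropped entirely
    if len(images) > 1 and not multi_image:
        return text, [images[0]]  # keep only the first image this turn
    return text, images

assert prepare_text_with_image_sketch("hi", [], 0, False) == "hi"
assert prepare_text_with_image_sketch("hi", ["a.png"], 1, False) == "hi"
assert prepare_text_with_image_sketch("hi", ["a.png", "b.png"], 0, False) == ("hi", ["a.png"])
assert prepare_text_with_image_sketch("hi", ["a.png", "b.png"], 0, True) == ("hi", ["a.png", "b.png"])
```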