diff --git a/fastchat/conversation.py b/fastchat/conversation.py
index ac4cc4cf3..21078938e 100644
--- a/fastchat/conversation.py
+++ b/fastchat/conversation.py
@@ -365,12 +365,16 @@ def to_gradio_chatbot(self):
if i % 2 == 0:
if type(msg) is tuple:
msg, images = msg
- image = images[0]  # Only one image on gradio at one time
- if image.image_format == ImageFormat.URL:
-     img_str = f'<img src="{image.url}" alt="user upload image" />'
- elif image.image_format == ImageFormat.BYTES:
-     img_str = f'<img src="data:image/{image.filetype};base64,{image.base64_str}" alt="user upload image" />'
- msg = img_str + msg.replace("\n", "").strip()
+ combined_image_str = ""
+ for image in images:
+     if image.image_format == ImageFormat.URL:
+         img_str = f'<img src="{image.url}" alt="user upload image" />'
+     elif image.image_format == ImageFormat.BYTES:
+         img_str = f'<img src="data:image/{image.filetype};base64,{image.base64_str}" alt="user upload image" />'
+     combined_image_str += img_str
+ msg = combined_image_str + msg.replace("\n", "").strip()
ret.append([msg, None])
else:
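
The hunk above drops the single-image assumption and concatenates one <img> tag per image in front of the message text. A minimal self-contained sketch of the same rendering loop (the Image stand-in below is simplified; the real class lives in fastchat/conversation.py and also carries conversion logic):

from dataclasses import dataclass
from enum import Enum, auto


class ImageFormat(Enum):
    URL = auto()
    BYTES = auto()


@dataclass
class Image:  # simplified stand-in for fastchat's Image class
    image_format: ImageFormat
    url: str = ""
    filetype: str = "png"
    base64_str: str = ""


def render_images_html(images: list) -> str:
    # Concatenate one <img> tag per image, mirroring the new loop above.
    parts = []
    for image in images:
        if image.image_format == ImageFormat.URL:
            parts.append(f'<img src="{image.url}" alt="user upload image" />')
        else:  # ImageFormat.BYTES
            parts.append(
                f'<img src="data:image/{image.filetype};base64,'
                f'{image.base64_str}" alt="user upload image" />'
            )
    return "".join(parts)

With this, every image attached to a turn survives into the Gradio transcript instead of only images[0].
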
diff --git a/fastchat/serve/gradio_block_arena_vision.py b/fastchat/serve/gradio_block_arena_vision.py
index 67f70d95b..dadb360e0 100644
--- a/fastchat/serve/gradio_block_arena_vision.py
+++ b/fastchat/serve/gradio_block_arena_vision.py
@@ -85,10 +85,6 @@ def set_visible_image(textbox):
images = textbox["files"]
if len(images) == 0:
return invisible_image_column
- elif len(images) > 1:
- gr.Warning(
- "We only support single image conversations. Please start a new round if you would like to chat using this image."
- )
return visible_image_column
@@ -155,18 +151,14 @@ def clear_history(request: gr.Request):
ip = get_ip(request)
logger.info(f"clear_history. ip: {ip}")
state = None
- return (state, [], enable_multimodal_clear_input, invisible_text, invisible_btn) + (
- disable_btn,
- ) * 5
+ return (state, [], enable_multimodal_clear_input) + (disable_btn,) * 5
def clear_history_example(request: gr.Request):
ip = get_ip(request)
logger.info(f"clear_history_example. ip: {ip}")
state = None
- return (state, [], enable_multimodal_keep_input, invisible_text, invisible_btn) + (
- disable_btn,
- ) * 5
+ return (state, [], enable_multimodal_keep_input) + (disable_btn,) * 5
# TODO(Chris): At some point, we would like this to be a live-reporting feature.
@@ -174,13 +166,32 @@ def report_csam_image(state, image):
pass
-def _prepare_text_with_image(state, text, images):
+def _prepare_text_with_image(
+ state: State, text: str, images: List[Image], context: Context
+):
if len(images) > 0:
- if len(state.conv.get_images()) > 0:
- # reset convo with new image
- state.conv = get_conversation_template(state.model_name)
+ model_supports_multi_image = context.api_endpoint_info[state.model_name].get(
+ "multi_image", False
+ )
+ num_previous_images = len(state.conv.get_images())
+ images_interleaved_with_text_exists_but_model_does_not_support = (
+ num_previous_images > 0 and not model_supports_multi_image
+ )
+ multiple_image_one_turn_but_model_does_not_support = (
+ len(images) > 1 and not model_supports_multi_image
+ )
+ if images_interleaved_with_text_exists_but_model_does_not_support:
+ gr.Warning(
+ f"The model does not support interleaved image/text. We only use the very first image."
+ )
+ return text
+ elif multiple_image_one_turn_but_model_does_not_support:
+ gr.Warning(
+ f"The model does not support multiple images. Only the first image will be used."
+ )
+ return text, [images[0]]
- text = text, [images[0]]
+ text = text, images
return text
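
The rewritten helper gates multi-image behavior on a per-model "multi_image" flag read from context.api_endpoint_info. A sketch of just the decision logic, under the assumption that the endpoint info is a plain dict keyed by model name (choose_images is a hypothetical name, not a helper in the PR):

def choose_images(num_previous_images, new_images, model_name, api_endpoint_info):
    # Returns the images to attach for this turn, mirroring the gating above.
    supports_multi = api_endpoint_info.get(model_name, {}).get("multi_image", False)
    if supports_multi:
        return new_images   # interleaving and multiple images per turn are fine
    if num_previous_images > 0:
        return []           # no interleaved image/text: drop this turn's images
    return new_images[:1]   # otherwise, at most one image for the conversation
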
@@ -192,9 +203,10 @@ def convert_images_to_conversation_format(images):
MAX_NSFW_ENDPOINT_IMAGE_SIZE_IN_MB = 5 / 1.5
conv_images = []
if len(images) > 0:
- conv_image = Image(url=images[0])
- conv_image.to_conversation_format(MAX_NSFW_ENDPOINT_IMAGE_SIZE_IN_MB)
- conv_images.append(conv_image)
+ for image in images:
+ conv_image = Image(url=image)
+ conv_image.to_conversation_format(MAX_NSFW_ENDPOINT_IMAGE_SIZE_IN_MB)
+ conv_images.append(conv_image)
return conv_images
@@ -223,9 +235,7 @@ def add_text(state, model_selector, chat_input, context: Context, request: gr.Re
if len(text) <= 0:
state.skip_next = True
- return (state, state.to_gradio_chatbot(), None, "", no_change_btn) + (
- no_change_btn,
- ) * 5
+ return (state, state.to_gradio_chatbot(), None) + (no_change_btn,) * 5
all_conv_text = state.conv.get_prompt()
all_conv_text = all_conv_text[-2000:] + "\nuser: " + text
@@ -233,17 +243,26 @@ def add_text(state, model_selector, chat_input, context: Context, request: gr.Re
images = convert_images_to_conversation_format(images)
# Use the first state to get the moderation response because this is based on user input so it is independent of the model
- moderation_image_input = images[0] if len(images) > 0 else None
moderation_type_to_response_map = (
state.content_moderator.image_and_text_moderation_filter(
- moderation_image_input, text, [state.model_name], do_moderation=False
+ images, text, [state.model_name], do_moderation=False
)
)
text_flagged, nsfw_flag, csam_flag = (
moderation_type_to_response_map["text_moderation"]["flagged"],
- moderation_type_to_response_map["nsfw_moderation"]["flagged"],
- moderation_type_to_response_map["csam_moderation"]["flagged"],
+ any(
+ [
+ response["flagged"]
+ for response in moderation_type_to_response_map["nsfw_moderation"]
+ ]
+ ),
+ any(
+ [
+ response["flagged"]
+ for response in moderation_type_to_response_map["csam_moderation"]
+ ]
+ ),
)
if csam_flag:
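
Since nsfw_moderation and csam_moderation are now lists with one response per image, the callers reduce them with any(). A quick illustration of that reduction over an assumed response shape:

moderation_type_to_response_map = {
    "text_moderation": {"flagged": False},
    "nsfw_moderation": [{"flagged": False}, {"flagged": True}],
    "csam_moderation": [{"flagged": False}, {"flagged": False}],
}

text_flagged = moderation_type_to_response_map["text_moderation"]["flagged"]
nsfw_flag = any(r["flagged"] for r in moderation_type_to_response_map["nsfw_moderation"])
csam_flag = any(r["flagged"] for r in moderation_type_to_response_map["csam_moderation"])
assert (text_flagged, nsfw_flag, csam_flag) == (False, True, False)

A generator expression (as here) avoids materializing the intermediate list that the PR's any([...]) builds; the two are equivalent.
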
@@ -254,13 +273,15 @@ def add_text(state, model_selector, chat_input, context: Context, request: gr.Re
if text_flagged or nsfw_flag:
logger.info(f"violate moderation. ip: {ip}. text: {text}")
gradio_chatbot_before_user_input = state.to_gradio_chatbot()
- post_processed_text = _prepare_text_with_image(state, text, images)
+ post_processed_text = _prepare_text_with_image(state, text, images, context)
state.conv.append_message(state.conv.roles[0], post_processed_text)
state.skip_next = True
gr.Warning(MODERATION_MSG)
- return (state, gradio_chatbot_before_user_input, None, "", no_change_btn) + (
- no_change_btn,
- ) * 5
+ return (
+ state,
+ gradio_chatbot_before_user_input,
+ None,
+ ) + (no_change_btn,) * 5
if (len(state.conv.messages) - state.conv.offset) // 2 >= CONVERSATION_TURN_LIMIT:
logger.info(f"conversation turn limit. ip: {ip}. text: {text}")
@@ -269,20 +290,16 @@ def add_text(state, model_selector, chat_input, context: Context, request: gr.Re
state,
state.to_gradio_chatbot(),
{"text": CONVERSATION_LIMIT_MSG},
- "",
- no_change_btn,
) + (no_change_btn,) * 5
text = text[:INPUT_CHAR_LEN_LIMIT] # Hard cut-off
- text = _prepare_text_with_image(state, text, images)
+ text = _prepare_text_with_image(state, text, images, context)
state.conv.append_message(state.conv.roles[0], text)
state.conv.append_message(state.conv.roles[1], None)
return (
state,
state.to_gradio_chatbot(),
- disable_multimodal,
- visible_text,
- enable_btn,
+ None,
) + (disable_btn,) * 5
@@ -344,17 +361,6 @@ def build_single_vision_language_model_ui(
)
with gr.Row():
- textbox = gr.Textbox(
- show_label=False,
- placeholder="👉 Enter your prompt and press ENTER",
- elem_id="input_box",
- visible=False,
- )
-
- send_btn = gr.Button(
- value="Send", variant="primary", scale=0, visible=False, interactive=False
- )
-
multimodal_textbox = gr.MultimodalTextbox(
file_types=["image"],
show_label=False,
@@ -409,19 +415,21 @@ def build_single_vision_language_model_ui(
upvote_btn.click(
upvote_last_response,
[state, model_selector],
- [textbox, upvote_btn, downvote_btn, flag_btn],
+ [multimodal_textbox, upvote_btn, downvote_btn, flag_btn],
)
downvote_btn.click(
downvote_last_response,
[state, model_selector],
- [textbox, upvote_btn, downvote_btn, flag_btn],
+ [multimodal_textbox, upvote_btn, downvote_btn, flag_btn],
)
flag_btn.click(
flag_last_response,
[state, model_selector],
- [textbox, upvote_btn, downvote_btn, flag_btn],
+ [multimodal_textbox, upvote_btn, downvote_btn, flag_btn],
)
- regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then(
+ regenerate_btn.click(
+ regenerate, state, [state, chatbot, multimodal_textbox] + btn_list
+ ).then(
bot_response,
[state, temperature, top_p, max_output_tokens],
[state, chatbot] + btn_list,
@@ -429,47 +437,23 @@ def build_single_vision_language_model_ui(
clear_btn.click(
clear_history,
None,
- [state, chatbot, multimodal_textbox, textbox, send_btn] + btn_list,
+ [state, chatbot, multimodal_textbox] + btn_list,
)
model_selector.change(
clear_history,
None,
- [state, chatbot, multimodal_textbox, textbox, send_btn] + btn_list,
+ [state, chatbot, multimodal_textbox] + btn_list,
).then(set_visible_image, [multimodal_textbox], [image_column])
multimodal_textbox.input(add_image, [multimodal_textbox], [imagebox]).then(
set_visible_image, [multimodal_textbox], [image_column]
- ).then(
- clear_history_example,
- None,
- [state, chatbot, multimodal_textbox, textbox, send_btn] + btn_list,
)
multimodal_textbox.submit(
add_text,
[state, model_selector, multimodal_textbox, context_state],
- [state, chatbot, multimodal_textbox, textbox, send_btn] + btn_list,
- ).then(set_invisible_image, [], [image_column]).then(
- bot_response,
- [state, temperature, top_p, max_output_tokens],
- [state, chatbot] + btn_list,
- )
-
- textbox.submit(
- add_text,
- [state, model_selector, textbox, context_state],
- [state, chatbot, multimodal_textbox, textbox, send_btn] + btn_list,
- ).then(set_invisible_image, [], [image_column]).then(
- bot_response,
- [state, temperature, top_p, max_output_tokens],
- [state, chatbot] + btn_list,
- )
-
- send_btn.click(
- add_text,
- [state, model_selector, textbox, context_state],
- [state, chatbot, multimodal_textbox, textbox, send_btn] + btn_list,
+ [state, chatbot, multimodal_textbox] + btn_list,
).then(set_invisible_image, [], [image_column]).then(
bot_response,
[state, temperature, top_p, max_output_tokens],
@@ -481,10 +465,6 @@ def build_single_vision_language_model_ui(
get_vqa_sample, # First, get the VQA sample
[], # Pass the path to the VQA samples
[multimodal_textbox, imagebox], # Outputs are textbox and imagebox
- ).then(set_visible_image, [multimodal_textbox], [image_column]).then(
- clear_history_example,
- None,
- [state, chatbot, multimodal_textbox, textbox, send_btn] + btn_list,
- )
+ ).then(set_visible_image, [multimodal_textbox], [image_column])
return [state, model_selector]
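
With the hidden textbox/send_btn pair removed, multimodal_textbox is the single input component and every event chain reads from and writes to it. A stripped-down sketch of the wiring pattern, not the arena's actual handlers (MultimodalTextbox values are dicts of the form {"text": ..., "files": [...]}, which is why set_visible_image reads textbox["files"]):

import gradio as gr


def add_text(chat_input, history):
    # chat_input: {"text": str, "files": [file paths]}
    history = history + [[chat_input["text"], None]]
    return history, None  # returning None clears the textbox, as the PR does


with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    multimodal_textbox = gr.MultimodalTextbox(
        file_types=["image"],
        show_label=False,
        placeholder="Enter your prompt or add image here",
    )
    multimodal_textbox.submit(
        add_text,
        [multimodal_textbox, chatbot],
        [chatbot, multimodal_textbox],
    )

demo.launch()
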
diff --git a/fastchat/serve/gradio_block_arena_vision_anony.py b/fastchat/serve/gradio_block_arena_vision_anony.py
index 33c920c9b..d24a54723 100644
--- a/fastchat/serve/gradio_block_arena_vision_anony.py
+++ b/fastchat/serve/gradio_block_arena_vision_anony.py
@@ -138,7 +138,7 @@ def clear_history_example(request: gr.Request):
[None] * num_sides
+ [None] * num_sides
+ anony_names
- + [enable_multimodal_keep_input, invisible_text, invisible_btn]
+ + [enable_multimodal_keep_input]
+ [invisible_btn] * 4
+ [disable_btn] * 2
+ [enable_btn]
@@ -244,7 +244,7 @@ def clear_history(request: gr.Request):
[None] * num_sides
+ [None] * num_sides
+ anony_names
- + [enable_multimodal_clear_input, invisible_text, invisible_btn]
+ + [enable_multimodal_clear_input]
+ [invisible_btn] * 4
+ [disable_btn] * 2
+ [enable_btn]
@@ -308,7 +308,7 @@ def add_text(
return (
states
+ [x.to_gradio_chatbot() for x in states]
- + [None, "", no_change_btn]
+ + [None]
+ [
no_change_btn,
]
@@ -322,16 +322,25 @@ def add_text(
images = convert_images_to_conversation_format(images)
# Use the first state to get the moderation response because this is based on user input so it is independent of the model
- moderation_image_input = images[0] if len(images) > 0 else None
moderation_type_to_response_map = states[
0
].content_moderator.image_and_text_moderation_filter(
- moderation_image_input, text, model_list, do_moderation=True
+ images, text, model_list, do_moderation=True
)
text_flagged, nsfw_flag, csam_flag = (
moderation_type_to_response_map["text_moderation"]["flagged"],
- moderation_type_to_response_map["nsfw_moderation"]["flagged"],
- moderation_type_to_response_map["csam_moderation"]["flagged"],
+ any(
+ [
+ response["flagged"]
+ for response in moderation_type_to_response_map["nsfw_moderation"]
+ ]
+ ),
+ any(
+ [
+ response["flagged"]
+ for response in moderation_type_to_response_map["csam_moderation"]
+ ]
+ ),
)
if csam_flag:
@@ -350,7 +359,7 @@ def add_text(
return (
states
+ [x.to_gradio_chatbot() for x in states]
- + [{"text": CONVERSATION_LIMIT_MSG}, "", no_change_btn]
+ + [{"text": CONVERSATION_LIMIT_MSG}]
+ [
no_change_btn,
]
@@ -364,7 +373,9 @@ def add_text(
# We call this before appending the text so it does not appear in the UI
gradio_chatbot_list = [x.to_gradio_chatbot() for x in states]
for i in range(num_sides):
- post_processed_text = _prepare_text_with_image(states[i], text, images)
+ post_processed_text = _prepare_text_with_image(
+ states[i], text, images, context
+ )
states[i].conv.append_message(states[i].conv.roles[0], post_processed_text)
states[i].skip_next = True
gr.Warning(MODERATION_MSG)
@@ -373,8 +384,6 @@ def add_text(
+ gradio_chatbot_list
+ [
None,
- "",
- no_change_btn,
]
+ [disable_btn] * 7
+ [""]
@@ -383,7 +392,7 @@ def add_text(
text = text[:BLIND_MODE_INPUT_CHAR_LEN_LIMIT] # Hard cut-off
for i in range(num_sides):
- post_processed_text = _prepare_text_with_image(states[i], text, images)
+ post_processed_text = _prepare_text_with_image(states[i], text, images, context)
states[i].conv.append_message(states[i].conv.roles[0], post_processed_text)
states[i].conv.append_message(states[i].conv.roles[1], None)
states[i].skip_next = False
@@ -395,7 +404,7 @@ def add_text(
return (
states
+ [x.to_gradio_chatbot() for x in states]
- + [disable_multimodal, visible_text, enable_btn]
+ + [None]
+ [
disable_btn,
]
@@ -487,13 +496,6 @@ def build_side_by_side_vision_ui_anony(context: Context, random_questions=None):
)
with gr.Row():
- textbox = gr.Textbox(
- show_label=False,
- placeholder="👉 Enter your prompt and press ENTER",
- elem_id="input_box",
- visible=False,
- )
-
multimodal_textbox = gr.MultimodalTextbox(
file_types=["image"],
show_label=False,
@@ -501,9 +503,6 @@ def build_side_by_side_vision_ui_anony(context: Context, random_questions=None):
placeholder="Enter your prompt or add image here",
elem_id="input_box",
)
- send_btn = gr.Button(
- value="Send", variant="primary", scale=0, visible=False, interactive=False
- )
with gr.Row() as button_row:
if random_questions:
@@ -555,25 +554,29 @@ def build_side_by_side_vision_ui_anony(context: Context, random_questions=None):
leftvote_btn.click(
leftvote_last_response,
states + model_selectors,
- model_selectors + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
+ model_selectors
+ + [multimodal_textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
)
rightvote_btn.click(
rightvote_last_response,
states + model_selectors,
- model_selectors + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
+ model_selectors
+ + [multimodal_textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
)
tie_btn.click(
tievote_last_response,
states + model_selectors,
- model_selectors + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
+ model_selectors
+ + [multimodal_textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
)
bothbad_btn.click(
bothbad_vote_last_response,
states + model_selectors,
- model_selectors + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
+ model_selectors
+ + [multimodal_textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
)
regenerate_btn.click(
- regenerate, states, states + chatbots + [textbox] + btn_list
+ regenerate, states, states + chatbots + [multimodal_textbox] + btn_list
).then(
bot_response_multi,
states + [temperature, top_p, max_output_tokens],
@@ -587,7 +590,7 @@ def build_side_by_side_vision_ui_anony(context: Context, random_questions=None):
states
+ chatbots
+ model_selectors
- + [multimodal_textbox, textbox, send_btn]
+ + [multimodal_textbox]
+ btn_list
+ [random_btn]
+ [slow_warning],
@@ -617,14 +620,6 @@ def build_side_by_side_vision_ui_anony(context: Context, random_questions=None):
multimodal_textbox.input(add_image, [multimodal_textbox], [imagebox]).then(
set_visible_image, [multimodal_textbox], [image_column]
- ).then(
- clear_history_example,
- None,
- states
- + chatbots
- + model_selectors
- + [multimodal_textbox, textbox, send_btn]
- + btn_list,
)
multimodal_textbox.submit(
@@ -632,7 +627,7 @@ def build_side_by_side_vision_ui_anony(context: Context, random_questions=None):
states + model_selectors + [multimodal_textbox, context_state],
states
+ chatbots
- + [multimodal_textbox, textbox, send_btn]
+ + [multimodal_textbox]
+ btn_list
+ [random_btn]
+ [slow_warning]
@@ -647,60 +642,11 @@ def build_side_by_side_vision_ui_anony(context: Context, random_questions=None):
btn_list,
)
- textbox.submit(
- add_text,
- states + model_selectors + [textbox, context_state],
- states
- + chatbots
- + [multimodal_textbox, textbox, send_btn]
- + btn_list
- + [random_btn]
- + [slow_warning]
- + [show_vote_buttons],
- ).then(
- bot_response_multi,
- states + [temperature, top_p, max_output_tokens],
- states + chatbots + btn_list,
- ).then(
- flash_buttons,
- [show_vote_buttons],
- btn_list,
- )
-
- send_btn.click(
- add_text,
- states + model_selectors + [textbox, context_state],
- states
- + chatbots
- + [multimodal_textbox, textbox, send_btn]
- + btn_list
- + [random_btn]
- + [slow_warning]
- + [show_vote_buttons],
- ).then(
- bot_response_multi,
- states + [temperature, top_p, max_output_tokens],
- states + chatbots + btn_list,
- ).then(
- flash_buttons,
- [show_vote_buttons],
- btn_list,
- )
-
if random_questions:
random_btn.click(
get_vqa_sample, # First, get the VQA sample
[], # Pass the path to the VQA samples
[multimodal_textbox, imagebox], # Outputs are textbox and imagebox
- ).then(set_visible_image, [multimodal_textbox], [image_column]).then(
- clear_history_example,
- None,
- states
- + chatbots
- + model_selectors
- + [multimodal_textbox, textbox, send_btn]
- + btn_list
- + [random_btn],
- )
+ ).then(set_visible_image, [multimodal_textbox], [image_column])
return states + model_selectors
diff --git a/fastchat/serve/gradio_block_arena_vision_named.py b/fastchat/serve/gradio_block_arena_vision_named.py
index 279eb391d..218eb344e 100644
--- a/fastchat/serve/gradio_block_arena_vision_named.py
+++ b/fastchat/serve/gradio_block_arena_vision_named.py
@@ -99,7 +99,7 @@ def clear_history_example(request: gr.Request):
return (
[None] * num_sides
+ [None] * num_sides
- + [enable_multimodal_keep_input, invisible_text, invisible_btn]
+ + [enable_multimodal_keep_input]
+ [invisible_btn] * 4
+ [disable_btn] * 2
)
@@ -184,7 +184,7 @@ def clear_history(request: gr.Request):
return (
[None] * num_sides
+ [None] * num_sides
- + [enable_multimodal_clear_input, invisible_text, invisible_btn]
+ + [enable_multimodal_clear_input]
+ [invisible_btn] * 4
+ [disable_btn] * 2
)
@@ -236,7 +236,7 @@ def add_text(
return (
states
+ [x.to_gradio_chatbot() for x in states]
- + [None, "", no_change_btn]
+ + [None]
+ [
no_change_btn,
]
@@ -253,7 +253,7 @@ def add_text(
images = convert_images_to_conversation_format(images)
# Use the first state to get the moderation response because this is based on user input so it is independent of the model
- moderation_image_input = images[0] if len(images) > 0 else None
+ moderation_image_input = images if len(images) > 0 else None
moderation_type_to_response_map = states[
0
].content_moderator.image_and_text_moderation_filter(
@@ -262,8 +262,18 @@ def add_text(
text_flagged, nsfw_flag, csam_flag = (
moderation_type_to_response_map["text_moderation"]["flagged"],
- moderation_type_to_response_map["nsfw_moderation"]["flagged"],
- moderation_type_to_response_map["csam_moderation"]["flagged"],
+ any(
+ [
+ response["flagged"]
+ for response in moderation_type_to_response_map["nsfw_moderation"]
+ ]
+ ),
+ any(
+ [
+ response["flagged"]
+ for response in moderation_type_to_response_map["csam_moderation"]
+ ]
+ ),
)
if csam_flag:
@@ -285,7 +295,7 @@ def add_text(
return (
states
+ [x.to_gradio_chatbot() for x in states]
- + [{"text": CONVERSATION_LIMIT_MSG}, "", no_change_btn]
+ + [{"text": CONVERSATION_LIMIT_MSG}]
+ [
no_change_btn,
]
@@ -296,14 +306,16 @@ def add_text(
logger.info(f"violate moderation. ip: {ip}. text: {text}")
gradio_chatbot_list = [x.to_gradio_chatbot() for x in states]
for i in range(num_sides):
- post_processed_text = _prepare_text_with_image(states[i], text, images)
+ post_processed_text = _prepare_text_with_image(
+ states[i], text, images, context
+ )
states[i].conv.append_message(states[i].conv.roles[0], post_processed_text)
states[i].skip_next = True
gr.Warning(MODERATION_MSG)
return (
states
+ gradio_chatbot_list
- + [None, "", no_change_btn]
+ + [None]
+ [
no_change_btn,
]
@@ -316,6 +328,7 @@ def add_text(
states[i],
text,
images,
+ context,
)
states[i].conv.append_message(states[i].conv.roles[0], post_processed_text)
states[i].conv.append_message(states[i].conv.roles[1], None)
@@ -324,7 +337,7 @@ def add_text(
return (
states
+ [x.to_gradio_chatbot() for x in states]
- + [disable_multimodal, visible_text, enable_btn]
+ + [None]
+ [
disable_btn,
]
@@ -417,17 +430,6 @@ def build_side_by_side_vision_ui_named(context: Context, random_questions=None):
)
with gr.Row():
- textbox = gr.Textbox(
- show_label=False,
- placeholder="👉 Enter your prompt and press ENTER",
- elem_id="input_box",
- visible=False,
- )
-
- send_btn = gr.Button(
- value="Send", variant="primary", scale=0, visible=False, interactive=False
- )
-
multimodal_textbox = gr.MultimodalTextbox(
file_types=["image"],
show_label=False,
@@ -486,25 +488,25 @@ def build_side_by_side_vision_ui_named(context: Context, random_questions=None):
leftvote_btn.click(
leftvote_last_response,
states + model_selectors,
- [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
+ [multimodal_textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
)
rightvote_btn.click(
rightvote_last_response,
states + model_selectors,
- [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
+ [multimodal_textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
)
tie_btn.click(
tievote_last_response,
states + model_selectors,
- [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
+ [multimodal_textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
)
bothbad_btn.click(
bothbad_vote_last_response,
states + model_selectors,
- [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
+ [multimodal_textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
)
regenerate_btn.click(
- regenerate, states, states + chatbots + [textbox] + btn_list
+ regenerate, states, states + chatbots + [multimodal_textbox] + btn_list
).then(
bot_response_multi,
states + [temperature, top_p, max_output_tokens],
@@ -515,7 +517,7 @@ def build_side_by_side_vision_ui_named(context: Context, random_questions=None):
clear_btn.click(
clear_history,
None,
- states + chatbots + [multimodal_textbox, textbox, send_btn] + btn_list,
+ states + chatbots + [multimodal_textbox] + btn_list,
)
share_js = """
@@ -544,45 +546,17 @@ def build_side_by_side_vision_ui_named(context: Context, random_questions=None):
model_selectors[i].change(
clear_history,
None,
- states + chatbots + [multimodal_textbox, textbox, send_btn] + btn_list,
+ states + chatbots + [multimodal_textbox] + btn_list,
).then(set_visible_image, [multimodal_textbox], [image_column])
multimodal_textbox.input(add_image, [multimodal_textbox], [imagebox]).then(
set_visible_image, [multimodal_textbox], [image_column]
- ).then(
- clear_history_example,
- None,
- states + chatbots + [multimodal_textbox, textbox, send_btn] + btn_list,
)
multimodal_textbox.submit(
add_text,
states + model_selectors + [multimodal_textbox, context_state],
- states + chatbots + [multimodal_textbox, textbox, send_btn] + btn_list,
- ).then(set_invisible_image, [], [image_column]).then(
- bot_response_multi,
- states + [temperature, top_p, max_output_tokens],
- states + chatbots + btn_list,
- ).then(
- flash_buttons, [], btn_list
- )
-
- textbox.submit(
- add_text,
- states + model_selectors + [textbox, context_state],
- states + chatbots + [multimodal_textbox, textbox, send_btn] + btn_list,
- ).then(set_invisible_image, [], [image_column]).then(
- bot_response_multi,
- states + [temperature, top_p, max_output_tokens],
- states + chatbots + btn_list,
- ).then(
- flash_buttons, [], btn_list
- )
-
- send_btn.click(
- add_text,
- states + model_selectors + [textbox, context_state],
- states + chatbots + [multimodal_textbox, textbox, send_btn] + btn_list,
+ states + chatbots + [multimodal_textbox] + btn_list,
).then(set_invisible_image, [], [image_column]).then(
bot_response_multi,
states + [temperature, top_p, max_output_tokens],
@@ -596,10 +570,6 @@ def build_side_by_side_vision_ui_named(context: Context, random_questions=None):
get_vqa_sample, # First, get the VQA sample
[], # Pass the path to the VQA samples
[multimodal_textbox, imagebox], # Outputs are textbox and imagebox
- ).then(set_visible_image, [multimodal_textbox], [image_column]).then(
- clear_history_example,
- None,
- states + chatbots + [multimodal_textbox, textbox, send_btn] + btn_list,
- )
+ ).then(set_visible_image, [multimodal_textbox], [image_column])
return states + model_selectors
diff --git a/fastchat/serve/gradio_global_state.py b/fastchat/serve/gradio_global_state.py
index e05022f18..de911985d 100644
--- a/fastchat/serve/gradio_global_state.py
+++ b/fastchat/serve/gradio_global_state.py
@@ -8,3 +8,4 @@ class Context:
all_text_models: List[str] = field(default_factory=list)
vision_models: List[str] = field(default_factory=list)
all_vision_models: List[str] = field(default_factory=list)
+ api_endpoint_info: dict = field(default_factory=dict)
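
For reference, the extended dataclass now carries the endpoint metadata that _prepare_text_with_image consults. The resulting shape, with a hypothetical registration showing where a model's multi_image capability would live:

from dataclasses import dataclass, field
from typing import List


@dataclass
class Context:
    text_models: List[str] = field(default_factory=list)
    all_text_models: List[str] = field(default_factory=list)
    vision_models: List[str] = field(default_factory=list)
    all_vision_models: List[str] = field(default_factory=list)
    api_endpoint_info: dict = field(default_factory=dict)


# Hypothetical entry; real entries come from the register_api_endpoint_file.
ctx = Context(api_endpoint_info={"example-vlm": {"multi_image": True}})
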
diff --git a/fastchat/serve/gradio_web_server.py b/fastchat/serve/gradio_web_server.py
index c4cef2e19..46ff0c421 100644
--- a/fastchat/serve/gradio_web_server.py
+++ b/fastchat/serve/gradio_web_server.py
@@ -228,6 +228,11 @@ def get_model_list(controller_url, register_api_endpoint_file, vision_arena):
return visible_models, models
+def _get_api_endpoint_info():
+ global api_endpoint_info
+ return api_endpoint_info
+
+
def load_demo_single(context: Context, query_params):
# default to text models
models = context.text_models
diff --git a/fastchat/serve/gradio_web_server_multi.py b/fastchat/serve/gradio_web_server_multi.py
index ade8abc6b..f19fab6ec 100644
--- a/fastchat/serve/gradio_web_server_multi.py
+++ b/fastchat/serve/gradio_web_server_multi.py
@@ -41,6 +41,7 @@
get_model_list,
load_demo_single,
get_ip,
+ _get_api_endpoint_info,
)
from fastchat.serve.monitor.monitor import build_leaderboard_tab
from fastchat.utils import (
@@ -318,8 +319,15 @@ def build_demo(context: Context, elo_results_file: str, leaderboard_table_file):
args.register_api_endpoint_file,
vision_arena=True,
)
-
- context = Context(text_models, all_text_models, vision_models, all_vision_models)
+ api_endpoint_info = _get_api_endpoint_info()
+
+ context = Context(
+ text_models,
+ all_text_models,
+ vision_models,
+ all_vision_models,
+ api_endpoint_info,
+ )
# Set authorization credentials
auth = None
diff --git a/fastchat/serve/moderation/moderator.py b/fastchat/serve/moderation/moderator.py
index 96bb30ae0..efafdcd88 100644
--- a/fastchat/serve/moderation/moderator.py
+++ b/fastchat/serve/moderation/moderator.py
@@ -35,7 +35,7 @@ def _text_moderation_filter(self, text: str) -> bool:
raise NotImplementedError
def image_and_text_moderation_filter(
- self, image: Image, text: str
+ self, images: List[Image], text: str
) -> Dict[str, Dict[str, Union[str, Dict[str, float]]]]:
"""Function that detects whether image and text violate moderation policies.
@@ -65,8 +65,8 @@ def update_last_moderation_response(
class AzureAndOpenAIContentModerator(BaseContentModerator):
_NON_TOXIC_IMAGE_MODERATION_MAP = {
- "nsfw_moderation": {"flagged": False},
- "csam_moderation": {"flagged": False},
+ "nsfw_moderation": [{"flagged": False}],
+ "csam_moderation": [{"flagged": False}],
}
def __init__(self, use_remote_storage: bool = False):
@@ -121,25 +121,33 @@ def _image_moderation_provider(self, image_bytes: bytes, api_type: str) -> bool:
return moderation_response_map
- def image_moderation_filter(self, image: Image):
- print(f"moderating image")
+ def image_moderation_filter(self, images: List[Image]):
+ print(f"moderating images")
- image_bytes = base64.b64decode(image.base64_str)
+ images_moderation_response: Dict[
+ str, List[Dict[str, Union[str, Dict[str, float]]]]
+ ] = {
+ "nsfw_moderation": [],
+ "csam_moderation": [],
+ }
- nsfw_flagged_map = self._image_moderation_provider(image_bytes, "nsfw")
+ for image in images:
+ image_bytes = base64.b64decode(image.base64_str)
- if nsfw_flagged_map["flagged"]:
- csam_flagged_map = self._image_moderation_provider(image_bytes, "csam")
- else:
- csam_flagged_map = {"flagged": False}
+ nsfw_flagged_map = self._image_moderation_provider(image_bytes, "nsfw")
- self.nsfw_flagged = nsfw_flagged_map["flagged"]
- self.csam_flagged = csam_flagged_map["flagged"]
+ if nsfw_flagged_map["flagged"]:
+ csam_flagged_map = self._image_moderation_provider(image_bytes, "csam")
+ else:
+ csam_flagged_map = {"flagged": False}
- return {
- "nsfw_moderation": nsfw_flagged_map,
- "csam_moderation": csam_flagged_map,
- }
+ self.nsfw_flagged |= nsfw_flagged_map["flagged"]
+ self.csam_flagged |= csam_flagged_map["flagged"]
+
+ images_moderation_response["nsfw_moderation"].append(nsfw_flagged_map)
+ images_moderation_response["csam_moderation"].append(csam_flagged_map)
+
+ return images_moderation_response
def _openai_moderation_filter(
self, text: str, custom_thresholds: dict = None
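
The reworked image_moderation_filter runs the NSFW check once per image, escalates to the CSAM check only for images the NSFW check flagged, and ORs the per-image results into the instance-level flags. A condensed sketch with the provider stubbed out (moderate stands in for self._image_moderation_provider):

import base64
from typing import Dict, List


def image_moderation_filter(images, moderate) -> Dict[str, List[dict]]:
    # moderate(image_bytes, api_type) -> {"flagged": bool, ...}
    out: Dict[str, List[dict]] = {"nsfw_moderation": [], "csam_moderation": []}
    nsfw_any = csam_any = False
    for image in images:
        image_bytes = base64.b64decode(image.base64_str)
        nsfw = moderate(image_bytes, "nsfw")
        # The CSAM check only runs on images the NSFW check already flagged.
        csam = moderate(image_bytes, "csam") if nsfw["flagged"] else {"flagged": False}
        nsfw_any |= nsfw["flagged"]
        csam_any |= csam["flagged"]
        out["nsfw_moderation"].append(nsfw)
        out["csam_moderation"].append(csam)
    return out
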
@@ -203,6 +211,7 @@ def text_moderation_filter(
do_moderation = True
break
+ moderation_response_map = {"flagged": False}
if do_moderation:
moderation_response_map = self._openai_moderation_filter(
text, custom_thresholds
@@ -212,7 +221,7 @@ def text_moderation_filter(
return {"text_moderation": moderation_response_map}
def image_and_text_moderation_filter(
- self, image: Image, text: str, model_list: List[str], do_moderation=True
+ self, images: List[Image], text: str, model_list: List[str], do_moderation=True
) -> Dict[str, bool]:
"""Function that detects whether image and text violate moderation policies using the Azure and OpenAI moderation APIs.
@@ -247,8 +256,8 @@ def image_and_text_moderation_filter(
print("moderating text: ", text)
text_flagged_map = self.text_moderation_filter(text, model_list, do_moderation)
- if image is not None:
- image_flagged_map = self.image_moderation_filter(image)
+ if images:
+ image_flagged_map = self.image_moderation_filter(images)
else:
image_flagged_map = self._NON_TOXIC_IMAGE_MODERATION_MAP
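
Note how the earlier change to _NON_TOXIC_IMAGE_MODERATION_MAP keeps the two paths shape-compatible: the no-image default is now a list of one response per moderation type, so the any() reductions in the callers work identically whether or not images were submitted. A quick check of that invariant:

_NON_TOXIC_IMAGE_MODERATION_MAP = {
    "nsfw_moderation": [{"flagged": False}],
    "csam_moderation": [{"flagged": False}],
}

# The no-image default reduces exactly like the per-image path:
assert not any(r["flagged"] for r in _NON_TOXIC_IMAGE_MODERATION_MAP["nsfw_moderation"])
assert not any(r["flagged"] for r in _NON_TOXIC_IMAGE_MODERATION_MAP["csam_moderation"])
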