Towards multi-image #3510

Open. Wants to merge 5 commits into base: moderation-log
fastchat/conversation.py (16 changes: 10 additions & 6 deletions)

@@ -365,12 +365,16 @@ def to_gradio_chatbot(self):
         if i % 2 == 0:
             if type(msg) is tuple:
                 msg, images = msg
-                image = images[0]  # Only one image on gradio at one time
-                if image.image_format == ImageFormat.URL:
-                    img_str = f'<img src="{image.url}" alt="user upload image" />'
-                elif image.image_format == ImageFormat.BYTES:
-                    img_str = f'<img src="data:image/{image.filetype};base64,{image.base64_str}" alt="user upload image" />'
-                msg = img_str + msg.replace("<image>\n", "").strip()
+                combined_image_str = ""
+                for image in images:
+                    if image.image_format == ImageFormat.URL:
+                        img_str = (
+                            f'<img src="{image.url}" alt="user upload image" />'
+                        )
+                    elif image.image_format == ImageFormat.BYTES:
+                        img_str = f'<img src="data:image/{image.filetype};base64,{image.base64_str}" alt="user upload image" />'
+                    combined_image_str += img_str
+                msg = combined_image_str + msg.replace("<image>\n", "").strip()

             ret.append([msg, None])
         else:
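For illustration, a standalone sketch of the new rendering behavior (FakeImage and render_turn are hypothetical stand-ins; field names follow the diff): each image in a turn becomes its own <img> tag, concatenated ahead of the message text.

    from dataclasses import dataclass

    @dataclass
    class FakeImage:
        # Hypothetical stand-in for fastchat's Image; field names follow the diff.
        image_format: str  # "url" or "bytes"
        url: str = ""
        filetype: str = "jpeg"
        base64_str: str = ""

    def render_turn(msg, images):
        # Mirror of the new loop: one <img> tag per image, concatenated in order.
        combined_image_str = ""
        for image in images:
            if image.image_format == "url":
                img_str = f'<img src="{image.url}" alt="user upload image" />'
            else:  # "bytes"
                img_str = f'<img src="data:image/{image.filetype};base64,{image.base64_str}" alt="user upload image" />'
            combined_image_str += img_str
        return combined_image_str + msg.replace("<image>\n", "").strip()

    print(render_turn("<image>\ndescribe both images",
                      [FakeImage("url", url="a.png"), FakeImage("url", url="b.png")]))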
fastchat/serve/gradio_block_arena_vision.py (144 changes: 62 additions & 82 deletions)

@@ -85,10 +85,6 @@ def set_visible_image(textbox):
     images = textbox["files"]
     if len(images) == 0:
         return invisible_image_column
-    elif len(images) > 1:
-        gr.Warning(
-            "We only support single image conversations. Please start a new round if you would like to chat using this image."
-        )

     return visible_image_column

@@ -155,32 +151,47 @@ def clear_history(request: gr.Request):
     ip = get_ip(request)
     logger.info(f"clear_history. ip: {ip}")
     state = None
-    return (state, [], enable_multimodal_clear_input, invisible_text, invisible_btn) + (
-        disable_btn,
-    ) * 5
+    return (state, [], enable_multimodal_clear_input) + (disable_btn,) * 5


 def clear_history_example(request: gr.Request):
     ip = get_ip(request)
     logger.info(f"clear_history_example. ip: {ip}")
     state = None
-    return (state, [], enable_multimodal_keep_input, invisible_text, invisible_btn) + (
-        disable_btn,
-    ) * 5
+    return (state, [], enable_multimodal_keep_input) + (disable_btn,) * 5


 # TODO(Chris): At some point, we would like this to be a live-reporting feature.
 def report_csam_image(state, image):
     pass


-def _prepare_text_with_image(state, text, images):
+def _prepare_text_with_image(
+    state: State, text: str, images: List[Image], context: Context
+):
     if len(images) > 0:
-        if len(state.conv.get_images()) > 0:
-            # reset convo with new image
-            state.conv = get_conversation_template(state.model_name)
+        model_supports_multi_image = context.api_endpoint_info[state.model_name].get(
+            "multi_image", False
+        )
+        num_previous_images = len(state.conv.get_images())
+        images_interleaved_with_text_exists_but_model_does_not_support = (
+            num_previous_images > 0 and not model_supports_multi_image
+        )
+        multiple_image_one_turn_but_model_does_not_support = (
+            len(images) > 1 and not model_supports_multi_image
+        )
+        if images_interleaved_with_text_exists_but_model_does_not_support:
+            gr.Warning(
+                f"The model does not support interleaved image/text. We only use the very first image."
+            )
+            return text
+        elif multiple_image_one_turn_but_model_does_not_support:
+            gr.Warning(
+                f"The model does not support multiple images. Only the first image will be used."
+            )
+            return text, [images[0]]

-        text = text, [images[0]]
+        text = text, images

     return text
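The branching above reduces to three cases keyed on the model's multi_image capability. A pure-function sketch with illustrative names (not the PR's code):

    def select_images_for_turn(new_images, num_previous_images, supports_multi_image):
        # Mirrors the three branches of _prepare_text_with_image:
        if not new_images:
            return new_images
        if num_previous_images > 0 and not supports_multi_image:
            # Interleaved image/text history unsupported: the new images are dropped.
            return []
        if len(new_images) > 1 and not supports_multi_image:
            # Several images in one turn unsupported: keep only the first.
            return new_images[:1]
        return new_images

    assert select_images_for_turn(["a.png", "b.png"], 0, False) == ["a.png"]
    assert select_images_for_turn(["a.png", "b.png"], 0, True) == ["a.png", "b.png"]
    assert select_images_for_turn(["a.png"], 2, False) == []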

@@ -192,9 +203,10 @@ def convert_images_to_conversation_format(images):
     MAX_NSFW_ENDPOINT_IMAGE_SIZE_IN_MB = 5 / 1.5
     conv_images = []
     if len(images) > 0:
-        conv_image = Image(url=images[0])
-        conv_image.to_conversation_format(MAX_NSFW_ENDPOINT_IMAGE_SIZE_IN_MB)
-        conv_images.append(conv_image)
+        for image in images:
+            conv_image = Image(url=image)
+            conv_image.to_conversation_format(MAX_NSFW_ENDPOINT_IMAGE_SIZE_IN_MB)
+            conv_images.append(conv_image)

     return conv_images
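A usage sketch (assuming fastchat is importable and the paths point at real image files on disk): each uploaded path now yields its own conversation-format image, where the old code kept at most one.

    from fastchat.serve.gradio_block_arena_vision import (
        convert_images_to_conversation_format,
    )

    paths = ["cat.png", "dog.png"]  # hypothetical uploads; must exist on disk
    conv_images = convert_images_to_conversation_format(paths)
    assert len(conv_images) == len(paths)  # the old code kept at most one element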

@@ -223,27 +235,34 @@ def add_text(state, model_selector, chat_input, context: Context, request: gr.Request):

     if len(text) <= 0:
         state.skip_next = True
-        return (state, state.to_gradio_chatbot(), None, "", no_change_btn) + (
-            no_change_btn,
-        ) * 5
+        return (state, state.to_gradio_chatbot(), None) + (no_change_btn,) * 5

     all_conv_text = state.conv.get_prompt()
     all_conv_text = all_conv_text[-2000:] + "\nuser: " + text

     images = convert_images_to_conversation_format(images)

     # Use the first state to get the moderation response because this is based on user input so it is independent of the model
-    moderation_image_input = images[0] if len(images) > 0 else None
     moderation_type_to_response_map = (
         state.content_moderator.image_and_text_moderation_filter(
-            moderation_image_input, text, [state.model_name], do_moderation=False
+            images, text, [state.model_name], do_moderation=False
         )
     )

     text_flagged, nsfw_flag, csam_flag = (
         moderation_type_to_response_map["text_moderation"]["flagged"],
-        moderation_type_to_response_map["nsfw_moderation"]["flagged"],
-        moderation_type_to_response_map["csam_moderation"]["flagged"],
+        any(
+            [
+                response["flagged"]
+                for response in moderation_type_to_response_map["nsfw_moderation"]
+            ]
+        ),
+        any(
+            [
+                response["flagged"]
+                for response in moderation_type_to_response_map["csam_moderation"]
+            ]
+        ),
     )

     if csam_flag:
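The moderation filter now receives the full image list, and the aggregation above implies the NSFW and CSAM checks return one response per image; a turn is flagged when any image is flagged. A minimal sketch of the assumed response shape:

    moderation_type_to_response_map = {
        "text_moderation": {"flagged": False},
        "nsfw_moderation": [{"flagged": False}, {"flagged": True}],  # one entry per image
        "csam_moderation": [{"flagged": False}, {"flagged": False}],
    }

    nsfw_flag = any(
        response["flagged"]
        for response in moderation_type_to_response_map["nsfw_moderation"]
    )
    csam_flag = any(
        response["flagged"]
        for response in moderation_type_to_response_map["csam_moderation"]
    )
    assert nsfw_flag is True and csam_flag is False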
@@ -254,13 +273,15 @@ def add_text(state, model_selector, chat_input, context: Context, request: gr.Request):
     if text_flagged or nsfw_flag:
         logger.info(f"violate moderation. ip: {ip}. text: {text}")
         gradio_chatbot_before_user_input = state.to_gradio_chatbot()
-        post_processed_text = _prepare_text_with_image(state, text, images)
+        post_processed_text = _prepare_text_with_image(state, text, images, context)
         state.conv.append_message(state.conv.roles[0], post_processed_text)
         state.skip_next = True
         gr.Warning(MODERATION_MSG)
-        return (state, gradio_chatbot_before_user_input, None, "", no_change_btn) + (
-            no_change_btn,
-        ) * 5
+        return (
+            state,
+            gradio_chatbot_before_user_input,
+            None,
+        ) + (no_change_btn,) * 5

     if (len(state.conv.messages) - state.conv.offset) // 2 >= CONVERSATION_TURN_LIMIT:
         logger.info(f"conversation turn limit. ip: {ip}. text: {text}")
@@ -269,20 +290,16 @@ def add_text(state, model_selector, chat_input, context: Context, request: gr.Request):
             state,
             state.to_gradio_chatbot(),
             {"text": CONVERSATION_LIMIT_MSG},
-            "",
-            no_change_btn,
         ) + (no_change_btn,) * 5

     text = text[:INPUT_CHAR_LEN_LIMIT]  # Hard cut-off
-    text = _prepare_text_with_image(state, text, images)
+    text = _prepare_text_with_image(state, text, images, context)
     state.conv.append_message(state.conv.roles[0], text)
     state.conv.append_message(state.conv.roles[1], None)
     return (
         state,
         state.to_gradio_chatbot(),
-        disable_multimodal,
-        visible_text,
-        enable_btn,
+        None,
     ) + (disable_btn,) * 5
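Why every return tuple shrinks across these handlers: Gradio requires one return value per component in an event's outputs list. With textbox and send_btn gone, the outputs list drops from 10 components, [state, chatbot, multimodal_textbox, textbox, send_btn] plus five buttons, to 8, so each return site loses two elements. A toy arity check (stand-in values, not the real components):

    # Stand-ins for the real Gradio objects; only the tuple arity matters here.
    state, chatbot_data, disable_btn = object(), [], "button-update-stub"

    ret = (state, chatbot_data, None) + (disable_btn,) * 5  # new-style return
    assert len(ret) == 8  # [state, chatbot, multimodal_textbox] + 5 buttons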


@@ -344,17 +361,6 @@ def build_single_vision_language_model_ui(
     )

     with gr.Row():
-        textbox = gr.Textbox(
-            show_label=False,
-            placeholder="👉 Enter your prompt and press ENTER",
-            elem_id="input_box",
-            visible=False,
-        )
-
-        send_btn = gr.Button(
-            value="Send", variant="primary", scale=0, visible=False, interactive=False
-        )
-
         multimodal_textbox = gr.MultimodalTextbox(
             file_types=["image"],
             show_label=False,
@@ -409,67 +415,45 @@ def build_single_vision_language_model_ui(
     upvote_btn.click(
         upvote_last_response,
         [state, model_selector],
-        [textbox, upvote_btn, downvote_btn, flag_btn],
+        [multimodal_textbox, upvote_btn, downvote_btn, flag_btn],
     )
     downvote_btn.click(
         downvote_last_response,
         [state, model_selector],
-        [textbox, upvote_btn, downvote_btn, flag_btn],
+        [multimodal_textbox, upvote_btn, downvote_btn, flag_btn],
     )
     flag_btn.click(
         flag_last_response,
         [state, model_selector],
-        [textbox, upvote_btn, downvote_btn, flag_btn],
+        [multimodal_textbox, upvote_btn, downvote_btn, flag_btn],
     )
-    regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then(
+    regenerate_btn.click(
+        regenerate, state, [state, chatbot, multimodal_textbox] + btn_list
+    ).then(
         bot_response,
         [state, temperature, top_p, max_output_tokens],
         [state, chatbot] + btn_list,
     )
     clear_btn.click(
         clear_history,
         None,
-        [state, chatbot, multimodal_textbox, textbox, send_btn] + btn_list,
+        [state, chatbot, multimodal_textbox] + btn_list,
     )

     model_selector.change(
         clear_history,
         None,
-        [state, chatbot, multimodal_textbox, textbox, send_btn] + btn_list,
+        [state, chatbot, multimodal_textbox] + btn_list,
     ).then(set_visible_image, [multimodal_textbox], [image_column])

     multimodal_textbox.input(add_image, [multimodal_textbox], [imagebox]).then(
         set_visible_image, [multimodal_textbox], [image_column]
-    ).then(
-        clear_history_example,
-        None,
-        [state, chatbot, multimodal_textbox, textbox, send_btn] + btn_list,
     )

     multimodal_textbox.submit(
         add_text,
         [state, model_selector, multimodal_textbox, context_state],
-        [state, chatbot, multimodal_textbox, textbox, send_btn] + btn_list,
-    ).then(set_invisible_image, [], [image_column]).then(
-        bot_response,
-        [state, temperature, top_p, max_output_tokens],
-        [state, chatbot] + btn_list,
-    )
-
-    textbox.submit(
-        add_text,
-        [state, model_selector, textbox, context_state],
-        [state, chatbot, multimodal_textbox, textbox, send_btn] + btn_list,
-    ).then(set_invisible_image, [], [image_column]).then(
-        bot_response,
-        [state, temperature, top_p, max_output_tokens],
-        [state, chatbot] + btn_list,
-    )
-
-    send_btn.click(
-        add_text,
-        [state, model_selector, textbox, context_state],
-        [state, chatbot, multimodal_textbox, textbox, send_btn] + btn_list,
+        [state, chatbot, multimodal_textbox] + btn_list,
     ).then(set_invisible_image, [], [image_column]).then(
         bot_response,
         [state, temperature, top_p, max_output_tokens],
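With textbox and send_btn gone, the single MultimodalTextbox carries both the text and the uploaded files. A minimal self-contained Gradio sketch of that wiring pattern (stub handlers, not the PR's functions):

    import gradio as gr

    def add_text(history, chat_input):
        # MultimodalTextbox values are dicts: {"text": str, "files": [path, ...]}
        text, files = chat_input["text"], chat_input["files"]
        history = history + [(f"[{len(files)} image(s)] {text}", None)]
        return history, None  # None resets the input, as in the diff's returns

    def bot_response(history):
        history[-1] = (history[-1][0], "stub reply")
        return history

    with gr.Blocks() as demo:
        chatbot = gr.Chatbot()
        multimodal_textbox = gr.MultimodalTextbox(file_types=["image"], show_label=False)
        multimodal_textbox.submit(
            add_text, [chatbot, multimodal_textbox], [chatbot, multimodal_textbox]
        ).then(bot_response, [chatbot], [chatbot])

    if __name__ == "__main__":
        demo.launch()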
@@ -481,10 +465,6 @@ def build_single_vision_language_model_ui(
         get_vqa_sample,  # First, get the VQA sample
         [],  # Pass the path to the VQA samples
         [multimodal_textbox, imagebox],  # Outputs are textbox and imagebox
-    ).then(set_visible_image, [multimodal_textbox], [image_column]).then(
-        clear_history_example,
-        None,
-        [state, chatbot, multimodal_textbox, textbox, send_btn] + btn_list,
-    )
+    ).then(set_visible_image, [multimodal_textbox], [image_column])

     return [state, model_selector]