Skip to content

Commit

Permalink
Add support for thinking LLMs directly in gr.ChatInterface (#10305)
Browse files Browse the repository at this point in the history
* ungroup thoughts from messages

* rename messagebox to thought

* refactor

* * add metadata typing
* group thoughts when nested

* tweaks

* tweak

* add changeset

* fix expanded rotation

* border radius

* update thought design

* move spinner

* prevent circular reference

* revert border removal

* css tweaks

* border tweak

* move chevron to the left

* tweak nesting logic

* thought group spacing

* update run.py

* icon changes

* format

* add changeset

* add nested thought demo

* changes

* changes

* changes

* add demo

* docs

* guide

* refactor styles and clean up logic

* revert demo change and add deeper nested thought to demo

* add optional duration to message types

* add nested thoughts story

* format

* guide

* change dropdown icon button

* remove requirement for id's in nested thoughts

* support markdown in thought title

* get thought content in copied value

* add funcs to utils

* move is_all_text

* remove comment

* notebook

* change bot padding

* changes

* changes

* changes

* panel css fix

* changes

* changes

* changes

* changes

* tweak thought content opacity

* more changes

* add changeset

* changes

* restore

* changes

* changes

* revert everything

* revert everything

* revert

* changes

* revert

* make changes to demo

* notebooks

* more docs

* format

* changes

* changes

* update demo

* fix typing issues

* chatbot

* document chatmessage helper class

* add changeset

* changes

* format

* docs

---------

Co-authored-by: Hannah <[email protected]>
Co-authored-by: gradio-pr-bot <[email protected]>
Co-authored-by: aliabd <[email protected]>
  • Loading branch information
4 people authored Jan 10, 2025
1 parent d2691e7 commit be40307
Show file tree
Hide file tree
Showing 16 changed files with 311 additions and 68 deletions.
7 changes: 7 additions & 0 deletions .changeset/rich-ducks-grow.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
"@gradio/chatbot": minor
"gradio": minor
"website": minor
---

feat:Add support for thinking LLMs directly in `gr.ChatInterface`
1 change: 1 addition & 0 deletions demo/chatinterface_nested_thoughts/run.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: chatinterface_nested_thoughts"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "from gradio import ChatMessage\n", "import time\n", "\n", "sleep_time = 0.1\n", "long_sleep_time = 1\n", "\n", "def generate_response(message, history):\n", " start_time = time.time()\n", " responses = [\n", " ChatMessage(\n", " content=\"In order to find the current weather in San Francisco, I will need to use my weather tool.\",\n", " )\n", " ]\n", " yield responses\n", " time.sleep(sleep_time)\n", "\n", " main_thought = ChatMessage(\n", " content=\"\",\n", " metadata={\"title\": \"Using Weather Tool\", \"id\": 1, \"status\": \"pending\"},\n", " )\n", "\n", " responses.append(main_thought)\n", "\n", " yield responses\n", " time.sleep(long_sleep_time)\n", " responses[-1].content = \"Will check: weather.com and sunny.org\"\n", " yield responses\n", " time.sleep(sleep_time)\n", " responses.append(\n", " ChatMessage(\n", " content=\"Received weather from weather.com.\",\n", " metadata={\"title\": \"Checking weather.com\", \"parent_id\": 1, \"id\": 2, \"duration\": 0.05},\n", " )\n", " )\n", " yield responses\n", "\n", " sunny_start_time = time.time()\n", " time.sleep(sleep_time)\n", " sunny_thought = ChatMessage(\n", " content=\"API Error when connecting to sunny.org \ud83d\udca5\",\n", " metadata={\"title\": \"Checking sunny.org\", \"parent_id\": 1, \"id\": 3, \"status\": \"pending\"},\n", " )\n", "\n", " responses.append(sunny_thought)\n", " yield responses\n", "\n", " time.sleep(sleep_time)\n", " responses.append(\n", " ChatMessage(\n", " content=\"Failed again\",\n", " 
metadata={\"title\": \"I will try again\", \"id\": 4, \"parent_id\": 3, \"duration\": 0.1},\n", "\n", " )\n", " )\n", " sunny_thought.metadata[\"status\"] = \"done\"\n", " sunny_thought.metadata[\"duration\"] = time.time() - sunny_start_time\n", "\n", " main_thought.metadata[\"status\"] = \"done\"\n", " main_thought.metadata[\"duration\"] = time.time() - start_time\n", "\n", " yield responses\n", "\n", " time.sleep(long_sleep_time)\n", "\n", " responses.append(\n", " ChatMessage(\n", " content=\"Based on the data only from weather.com, the current weather in San Francisco is 60 degrees and sunny.\",\n", " )\n", " )\n", " yield responses\n", "\n", "demo = gr.ChatInterface(\n", " generate_response,\n", " type=\"messages\",\n", " title=\"Nested Thoughts Chat Interface\",\n", " examples=[\"What is the weather in San Francisco right now?\"]\n", ")\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
81 changes: 81 additions & 0 deletions demo/chatinterface_nested_thoughts/run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import gradio as gr
from gradio import ChatMessage
import time

sleep_time = 0.1  # short delay (seconds) between streamed chat updates
long_sleep_time = 1  # longer delay used to simulate slow tool calls

def generate_response(message, history):
    """
    Demo chat function: stream a growing list of ChatMessage objects so the
    Chatbot renders nested "thought" bubbles (via metadata id/parent_id) as
    they progress from "pending" to "done".
    """
    t0 = time.time()
    messages = [
        ChatMessage(
            content="In order to find the current weather in San Francisco, I will need to use my weather tool.",
        )
    ]
    yield messages
    time.sleep(sleep_time)

    # Top-level thought; the child thoughts below nest under it via parent_id=1.
    tool_msg = ChatMessage(
        content="",
        metadata={"title": "Using Weather Tool", "id": 1, "status": "pending"},
    )
    messages.append(tool_msg)
    yield messages

    time.sleep(long_sleep_time)
    messages[-1].content = "Will check: weather.com and sunny.org"
    yield messages

    time.sleep(sleep_time)
    messages.append(
        ChatMessage(
            content="Received weather from weather.com.",
            metadata={"title": "Checking weather.com", "parent_id": 1, "id": 2, "duration": 0.05},
        )
    )
    yield messages

    sunny_t0 = time.time()
    time.sleep(sleep_time)
    sunny_msg = ChatMessage(
        content="API Error when connecting to sunny.org 💥",
        metadata={"title": "Checking sunny.org", "parent_id": 1, "id": 3, "status": "pending"},
    )
    messages.append(sunny_msg)
    yield messages

    time.sleep(sleep_time)
    messages.append(
        ChatMessage(
            content="Failed again",
            metadata={"title": "I will try again", "id": 4, "parent_id": 3, "duration": 0.1},
        )
    )
    # Close out the still-pending thoughts with their measured durations.
    sunny_msg.metadata["status"] = "done"
    sunny_msg.metadata["duration"] = time.time() - sunny_t0
    tool_msg.metadata["status"] = "done"
    tool_msg.metadata["duration"] = time.time() - t0
    yield messages

    time.sleep(long_sleep_time)
    messages.append(
        ChatMessage(
            content="Based on the data only from weather.com, the current weather in San Francisco is 60 degrees and sunny.",
        )
    )
    yield messages

# Wire the generator into a ChatInterface; type="messages" is required for
# ChatMessage objects (and their metadata/thoughts) to be rendered.
demo = gr.ChatInterface(
    generate_response,
    type="messages",
    title="Nested Thoughts Chat Interface",
    examples=["What is the weather in San Francisco right now?"]
)

if __name__ == "__main__":
    demo.launch()
2 changes: 1 addition & 1 deletion demo/chatinterface_options/run.ipynb
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: chatinterface_options"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import random\n", "\n", "example_code = \"\"\"\n", "Here's an example Python lambda function:\n", "\n", "lambda x: x + {}\n", "\n", "Is this correct?\n", "\"\"\"\n", "\n", "def chat(message, history):\n", " if message == \"Yes, that's correct.\":\n", " return \"Great!\"\n", " else:\n", " return {\n", " \"role\": \"assistant\",\n", " \"content\": example_code.format(random.randint(1, 100)),\n", " \"options\": [\n", " {\"value\": \"Yes, that's correct.\", \"label\": \"Yes\"},\n", " {\"value\": \"No\"}\n", " ]\n", " }\n", "\n", "demo = gr.ChatInterface(\n", " chat,\n", " type=\"messages\",\n", " examples=[\"Write an example Python lambda function.\"]\n", ")\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: chatinterface_options"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import random\n", "\n", "example_code = \"\"\"\n", "Here's an example Python lambda function:\n", "\n", "lambda x: x + {}\n", "\n", "Is this correct?\n", "\"\"\"\n", "\n", "def chat(message, history):\n", " if message == \"Yes, that's correct.\":\n", " return \"Great!\"\n", " else:\n", " return gr.ChatMessage(\n", " content=example_code.format(random.randint(1, 100)),\n", " options=[\n", " {\"value\": \"Yes, that's correct.\", \"label\": \"Yes\"},\n", " {\"value\": \"No\"}\n", " ]\n", " )\n", "\n", "demo = gr.ChatInterface(\n", " chat,\n", " type=\"messages\",\n", " examples=[\"Write an example Python lambda function.\"]\n", ")\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
11 changes: 5 additions & 6 deletions demo/chatinterface_options/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,13 @@ def chat(message, history):
if message == "Yes, that's correct.":
return "Great!"
else:
return {
"role": "assistant",
"content": example_code.format(random.randint(1, 100)),
"options": [
return gr.ChatMessage(
content=example_code.format(random.randint(1, 100)),
options=[
{"value": "Yes, that's correct.", "label": "Yes"},
{"value": "No"}
]
}
]
)

demo = gr.ChatInterface(
chat,
Expand Down
1 change: 1 addition & 0 deletions demo/chatinterface_thoughts/run.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: chatinterface_thoughts"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "from gradio import ChatMessage\n", "import time\n", "\n", "sleep_time = 0.5\n", "\n", "def simulate_thinking_chat(message, history):\n", " start_time = time.time()\n", " response = ChatMessage(\n", " content=\"\",\n", " metadata={\"title\": \"_Thinking_ step-by-step\", \"id\": 0, \"status\": \"pending\"}\n", " )\n", " yield response\n", "\n", " thoughts = [\n", " \"First, I need to understand the core aspects of the query...\",\n", " \"Now, considering the broader context and implications...\",\n", " \"Analyzing potential approaches to formulate a comprehensive answer...\",\n", " \"Finally, structuring the response for clarity and completeness...\"\n", " ]\n", "\n", " accumulated_thoughts = \"\"\n", " for thought in thoughts:\n", " time.sleep(sleep_time)\n", " accumulated_thoughts += f\"- {thought}\\n\\n\"\n", " response.content = accumulated_thoughts.strip()\n", " yield response\n", "\n", " response.metadata[\"status\"] = \"done\"\n", " response.metadata[\"duration\"] = time.time() - start_time\n", " yield response\n", "\n", " response = [\n", " response,\n", " ChatMessage(\n", " content=\"Based on my thoughts and analysis above, my response is: This dummy repro shows how thoughts of a thinking LLM can be progressively shown before providing its final answer.\"\n", " )\n", " ]\n", " yield response\n", "\n", "\n", "demo = gr.ChatInterface(\n", " simulate_thinking_chat,\n", " title=\"Thinking LLM Chat Interface \ud83e\udd14\",\n", " type=\"messages\",\n", ")\n", "\n", "if __name__ == 
\"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
49 changes: 49 additions & 0 deletions demo/chatinterface_thoughts/run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import gradio as gr
from gradio import ChatMessage
import time

sleep_time = 0.5  # delay (seconds) between simulated thinking steps

def simulate_thinking_chat(message, history):
    """
    Demo chat function: progressively stream a fake chain-of-thought inside a
    single "pending" thought bubble, mark it "done" with its duration, then
    yield the thought together with the final answer.
    """
    t0 = time.time()
    thinking = ChatMessage(
        content="",
        metadata={"title": "_Thinking_ step-by-step", "id": 0, "status": "pending"}
    )
    yield thinking

    steps = [
        "First, I need to understand the core aspects of the query...",
        "Now, considering the broader context and implications...",
        "Analyzing potential approaches to formulate a comprehensive answer...",
        "Finally, structuring the response for clarity and completeness..."
    ]

    shown = []
    for step in steps:
        time.sleep(sleep_time)
        shown.append(f"- {step}")
        # Equivalent to accumulating "- step\n\n" and stripping the tail.
        thinking.content = "\n\n".join(shown)
        yield thinking

    thinking.metadata["status"] = "done"
    thinking.metadata["duration"] = time.time() - t0
    yield thinking

    yield [
        thinking,
        ChatMessage(
            content="Based on my thoughts and analysis above, my response is: This dummy repro shows how thoughts of a thinking LLM can be progressively shown before providing its final answer."
        )
    ]


# type="messages" is required for ChatMessage objects (and their
# metadata/thoughts) to be rendered by the ChatInterface.
demo = gr.ChatInterface(
    simulate_thinking_chat,
    title="Thinking LLM Chat Interface 🤔",
    type="messages",
)

if __name__ == "__main__":
    demo.launch()
7 changes: 7 additions & 0 deletions gradio/chat_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import builtins
import copy
import dataclasses
import inspect
import os
import warnings
Expand All @@ -32,6 +33,7 @@
get_component_instance,
)
from gradio.components.chatbot import (
ChatMessage,
ExampleMessage,
Message,
MessageDict,
Expand Down Expand Up @@ -808,6 +810,11 @@ def _message_as_message_dict(
for msg in message:
if isinstance(msg, Message):
message_dicts.append(msg.model_dump())
elif isinstance(msg, ChatMessage):
msg.role = role
message_dicts.append(
dataclasses.asdict(msg, dict_factory=utils.dict_factory)
)
elif isinstance(msg, (str, Component)):
message_dicts.append({"role": role, "content": msg})
elif (
Expand Down
26 changes: 21 additions & 5 deletions gradio/components/chatbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ class MetadataDict(TypedDict):
title: Union[str, None]
id: NotRequired[int | str]
parent_id: NotRequired[int | str]
duration: NotRequired[float]
status: NotRequired[Literal["pending", "done"]]


class Option(TypedDict):
Expand All @@ -59,7 +61,6 @@ class MessageDict(TypedDict):
role: Literal["user", "assistant", "system"]
metadata: NotRequired[MetadataDict]
options: NotRequired[list[Option]]
duration: NotRequired[int]


class FileMessage(GradioModel):
Expand Down Expand Up @@ -87,14 +88,21 @@ class Metadata(GradioModel):
title: Optional[str] = None
id: Optional[int | str] = None
parent_id: Optional[int | str] = None
duration: Optional[float] = None
status: Optional[Literal["pending", "done"]] = None

def __setitem__(self, key: str, value: Any) -> None:
    """Allow dict-style assignment, e.g. ``metadata["status"] = "done"``."""
    setattr(self, key, value)

def __getitem__(self, key: str) -> Any:
    """Allow dict-style access, e.g. ``metadata["status"]``."""
    return getattr(self, key)


class Message(GradioModel):
    """A single chat message in the Chatbot's "messages" format (server-side pydantic counterpart of the `ChatMessage` dataclass)."""

    role: str  # "user", "assistant", or "system"
    metadata: Metadata = Field(default_factory=Metadata)  # thought / tool-usage info
    content: Union[str, FileMessage, ComponentMessage]
    options: Optional[list[Option]] = None  # suggested follow-up options
    duration: Optional[int] = None


class ExampleMessage(TypedDict):
Expand All @@ -110,13 +118,22 @@ class ExampleMessage(TypedDict):
] # list of file paths or URLs to be added to chatbot when example is clicked


@document()
@dataclass
class ChatMessage:
    """
    A dataclass to represent a message in the Chatbot component (type="messages").
    Parameters:
        content: The content of the message. Can be a string or a Gradio component.
        role: The role of the message, which determines the alignment of the message in the chatbot. Can be "user", "assistant", or "system". Defaults to "assistant".
        metadata: The metadata of the message, which is used to display intermediate thoughts / tool usage. Should be a dictionary with the following keys: "title" (required to display the thought), and optionally: "id" and "parent_id" (to nest thoughts), "duration" (to display the duration of the thought), "status" (to display the status of the thought).
        options: The options of the message. A list of Option objects, which are dictionaries with the following keys: "label" (the text to display in the option), and optionally "value" (the value to return when the option is selected if different from the label).
        duration: The duration of the message in seconds, displayed next to the message.
    """

    content: str | FileData | Component | FileDataDict | tuple | list
    role: Literal["user", "assistant", "system"] = "assistant"
    metadata: MetadataDict | Metadata = field(default_factory=Metadata)
    options: Optional[list[Option]] = None
    duration: Optional[int] = None


class ChatbotDataMessages(GradioRootModel):
Expand Down Expand Up @@ -545,7 +562,6 @@ def _postprocess_message_messages(
content=message.content, # type: ignore
metadata=message.metadata, # type: ignore
options=message.options,
duration=message.duration,
)
elif isinstance(message, Message):
return message
Expand Down
2 changes: 1 addition & 1 deletion gradio/monitoring_dashboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def gen_plot(start, end, selected_fn):
if selected_fn != "All":
df = df[df["function"] == selected_fn]
df = df[(df["time"] >= start) & (df["time"] <= end)]
df["time"] = pd.to_datetime(df["time"], unit="s")
df["time"] = pd.to_datetime(df["time"], unit="s") # type: ignore

unique_users = len(df["session_hash"].unique()) # type: ignore
total_requests = len(df)
Expand Down
13 changes: 13 additions & 0 deletions gradio/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1609,3 +1609,16 @@ def get_icon_path(icon_name: str) -> str:
set_static_paths(icon_path)
return icon_path
raise ValueError(f"Icon file not found: {icon_name}")


def dict_factory(items):
    """
    A utility function to convert a dataclass that includes pydantic fields to
    a dictionary: any value exposing ``model_dump`` (a pydantic model) is
    serialized to a plain dict, all other values pass through unchanged.
    """
    return {
        key: value.model_dump() if hasattr(value, "model_dump") else value
        for key, value in items
    }
Loading

0 comments on commit be40307

Please sign in to comment.