Skip to content

Commit

Permalink
update frontend ui, replace if to while loop to support multiple sear…
Browse files Browse the repository at this point in the history
…ch trial, and make sure our code aligns with the format
  • Loading branch information
tsunghan-wu committed Oct 28, 2024
1 parent 56c0537 commit 27d5943
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 54 deletions.
1 change: 0 additions & 1 deletion fastchat/model/model_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,6 @@ def get_default_conv_template(self, model_path: str) -> Conversation:


class ReActAgentAdapter(BaseModelAdapter):

def match(self, model_path: str):
return "react-agent" in model_path.lower()

Expand Down
106 changes: 64 additions & 42 deletions fastchat/serve/gradio_web_server_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,7 +535,7 @@ def bot_response(
if model_api_dict.get("agent-mode", False):
html_code = ' <span class="cursor"></span> '
conv.update_last_message(html_code)

yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
try:
while True:
Expand All @@ -548,57 +548,79 @@ def bot_response(
top_p,
max_new_tokens,
state,
)
)
data = {"text": ""}
conv.update_last_message("Thinking...")
yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
for i, data in enumerate(stream_iter):
output = data["text"].strip()
system_conv.update_last_message(output)
parsed_response = parse_json_from_string(output)
break
# break
except json.JSONDecodeError as e:
print('JSONDecodeError: ', e)

if "action" in parsed_response:
conv.update_last_message(f"Web search with {parsed_response['action']['arguments']['key_words']} keywords...")
yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
action = parsed_response["action"]
assert "web_search" == action["name"]
arguments = action["arguments"]
web_search_result = web_search(**arguments)
system_conv.append_message(system_conv.roles[1], f"Web search result: {web_search_result}")
system_conv.append_message(system_conv.roles[1], None)
conv.update_last_message(f'Web search result:\n\n{web_search_result}')
yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5

# generate answer after web search
last_message = conv.messages[-1][1]
stream_iter = get_api_provider_stream_iter(
system_conv,
model_name,
model_api_dict,
temperature,
top_p,
max_new_tokens,
state,
print("JSONDecodeError: ", e)
last_message = None
maximum_action_steps = 5
current_action_steps = 0
while "action" in parsed_response:
current_action_steps += 1
if current_action_steps > maximum_action_steps:
break
conv.update_last_message(
f"Web search with {parsed_response['action']['arguments']['key_words']} keywords..."
)
yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
action = parsed_response["action"]
assert "web_search" == action["name"]
arguments = action["arguments"]
web_search_result = web_search(**arguments)
system_conv.append_message(
system_conv.roles[1], f"Web search result: {web_search_result}"
)
system_conv.append_message(system_conv.roles[1], None)
conv.update_last_message(
f"Web search result: \n\n{web_search_result}"
)
data = {"text": ""}
for i, data in enumerate(stream_iter):
output = data["text"].strip()
system_conv.update_last_message(output)
conv.update_last_message(f"{last_message}\n\n{output}▌")
yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
parsed_response = parse_json_from_string(output)

if "answer" in parsed_response:
conv.update_last_message(f"{last_message}\n\n{parsed_response['answer'].strip()}")
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5

elif "answer" in parsed_response:
conv.update_last_message(parsed_response["answer"].strip())

# generate answer after web search
last_message = conv.messages[-1][1]
stream_iter = get_api_provider_stream_iter(
system_conv,
model_name,
model_api_dict,
temperature,
top_p,
max_new_tokens,
state,
)
data = {"text": ""}
for i, data in enumerate(stream_iter):
output = data["text"].strip()
system_conv.update_last_message(output)
conv.update_last_message(f"{last_message}{output}▌")
yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
parsed_response = parse_json_from_string(output)

if "answer" in parsed_response:
conv.update_last_message(
f"{last_message}{parsed_response['answer'].strip()}"
)
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
break

assert (
"answer" in parsed_response
), f"parsed_response: {parsed_response}"
if last_message is None:
conv.update_last_message(parsed_response["answer"].strip())
else:
conv.update_last_message(
f"{last_message}{parsed_response['answer'].strip()}"
)
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5

break

except requests.exceptions.RequestException as e:
conv.update_last_message(
f"{SERVER_ERROR_MSG}\n\n"
Expand Down Expand Up @@ -635,7 +657,7 @@ def bot_response(
top_p,
max_new_tokens,
state,
)
)
html_code = ' <span class="cursor"></span> '

# conv.update_last_message("▌")
Expand Down
29 changes: 18 additions & 11 deletions fastchat/tools/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

YOU_SEARCH_API_KEY = "YOUR API KEY"


def get_ai_snippets_for_query(query, num_web_results=1):
headers = {"X-API-Key": YOU_SEARCH_API_KEY}
params = {"query": query, "num_web_results": num_web_results}
Expand All @@ -11,19 +12,25 @@ def get_ai_snippets_for_query(query, num_web_results=1):
headers=headers,
).json()


def format_search_results(results):
formatted_results = ''
results = results['hits']
for result in results:
formatted_results += result['url'] + '\n'
formatted_results += result['title'] + '\n'
for snippet in result['snippets']:
formatted_results += snippet + '\n'
formatted_results += '--------------------------------\n'
formatted_results += '--------------------------------\n'
formatted_results = ""
results = results["hits"]
for idx, result in enumerate(results):
formatted_results += (
f"{idx+1}. [" + result["title"] + "](" + result["url"] + ")" + "\n"
)
if len(result["snippets"]) > 0:
formatted_results += "Descriptions: \n"
for snippet in result["snippets"]:
formatted_results += "- " + snippet + "\n"
formatted_results += "--------------------------------\n"
return formatted_results


def web_search(key_words, topk=1):
web_search_results = get_ai_snippets_for_query(query=key_words, num_web_results=topk)
web_search_results = get_ai_snippets_for_query(
query=key_words, num_web_results=topk
)
web_search_results = format_search_results(web_search_results)
return web_search_results
return web_search_results

0 comments on commit 27d5943

Please sign in to comment.