[PY] fix: Chat Completion - Tools Fixes #1942

Merged (4 commits) on Aug 27, 2024
12 changes: 9 additions & 3 deletions python/packages/ai/teams/ai/ai.py
@@ -167,13 +167,14 @@ async def run(
                         output = await self._actions[ActionTypes.DO_COMMAND].invoke(
                             context, state, command, command.action
                         )
-                        loop = len(output) > 0
-                        state.temp.action_outputs[command.action] = output
+
+                        # Set output for action call
+                        if command.action_id:
+                            loop = True
+                            state.temp.action_outputs[command.action_id] = output or ""
+                        else:
+                            loop = len(output) > 0
+                            state.temp.action_outputs[command.action] = output
                     else:
                         output = await self._actions[ActionTypes.UNKNOWN_ACTION].invoke(
                             context, state, plan, command.action
@@ -190,7 +191,12 @@ async def run(
                     return False
 
                 state.temp.last_output = output
-                state.temp.input = output
+
+                if isinstance(command, PredictedDoCommand) and command.action_id:
+                    state.delete("temp.input")
+                else:
+                    state.temp.input = output
+
                 state.temp.input_files = []
 
                 if loop and self._options.allow_looping:
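Taken together, the new branch means a DO command that carries an `action_id` (i.e., one that originated from a model tool call) always keeps the loop going so its result can be fed back to the model, while plain actions keep the old output-based looping. A minimal runnable sketch of that branching (the `FakeDoCommand` stand-in and `store_output` helper are hypothetical, for illustration only):

from dataclasses import dataclass
from typing import Dict, Optional


@dataclass
class FakeDoCommand:
    """Stand-in for PredictedDoCommand; the real command carries more fields."""
    action: str
    action_id: Optional[str] = None


def store_output(command: FakeDoCommand, output: str, outputs: Dict[str, str]) -> bool:
    """Mirrors the branching added to AI.run(); returns the new `loop` flag."""
    if command.action_id:
        # Model-initiated tool call: always loop so the result is sent back.
        outputs[command.action_id] = output or ""
        return True
    # Plain DO command: loop only when the action produced output.
    outputs[command.action] = output
    return len(output) > 0


outputs: Dict[str, str] = {}
assert store_output(FakeDoCommand("lightsOn", action_id="call_1"), "", outputs) is True
assert outputs["call_1"] == ""
assert store_output(FakeDoCommand("lightsOn"), "", outputs) is False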
25 changes: 7 additions & 18 deletions python/packages/ai/teams/ai/clients/llm_client.py
@@ -7,7 +7,7 @@
 
 from dataclasses import dataclass, field
 from logging import Logger
-from typing import Any, List, Optional
+from typing import Any, List, Optional, Union
 
 from botbuilder.core import TurnContext
@@ -90,21 +90,6 @@ def __init__(self, options: LLMClientOptions) -> None:
 
         self._options = options
 
-    def add_action_output_to_history(self, memory: MemoryBase, id: str, results: str) -> None:
-        """
-        Adds the result from an `action_call` to the history.
-
-        Args:
-            memory (MemoryBase): An interface for accessing state values.
-            id (str): Id of the action that was called.
-            results (str): Results returned by the action call.
-        """
-        self._add_message_to_history(
-            memory=memory,
-            variable=self._options.history_variable,
-            message=Message(role="tool", action_call_id=id, content=results),
-        )
-
     async def complete_prompt(
         self,
         context: TurnContext,
@@ -207,10 +192,14 @@ async def complete_prompt(
             return PromptResponse(status="error", error=str(err))
 
     def _add_message_to_history(
-        self, memory: MemoryBase, variable: str, message: Message[Any]
+        self, memory: MemoryBase, variable: str, messages: Union[Message[Any], List[Message[Any]]]
     ) -> None:
+
         history: List[Message] = memory.get(variable) or []
-        history.append(message)
+        if isinstance(messages, list):
+            history.extend(messages)
+        else:
+            history.append(messages)
 
         if len(history) > self._options.max_history_messages:
             del history[0 : len(history) - self._options.max_history_messages]
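Accepting a batch lets the client append all parallel tool results to history in one call while still trimming to `max_history_messages`. A self-contained sketch of the same append/extend-and-trim pattern, with plain strings standing in for `Message` objects and the limit hard-coded:

from typing import List, Union

MAX_HISTORY = 3  # stands in for LLMClientOptions.max_history_messages


def add_to_history(history: List[str], messages: Union[str, List[str]]) -> None:
    # Accept one message or a batch of messages.
    if isinstance(messages, list):
        history.extend(messages)
    else:
        history.append(messages)
    # Trim from the front so only the most recent messages survive.
    if len(history) > MAX_HISTORY:
        del history[0 : len(history) - MAX_HISTORY]


history: List[str] = ["prompt"]
add_to_history(history, ["tool result 1", "tool result 2", "answer"])
assert history == ["tool result 1", "tool result 2", "answer"]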
11 changes: 10 additions & 1 deletion python/packages/ai/teams/ai/models/openai_model.py
@@ -282,13 +282,22 @@ async def complete_prompt(
             )
         )
 
-        input: Optional[Message] = None
+        input: Optional[Union[Message, List[Message]]] = None
         last_message = len(res.output) - 1
 
         # Skips the first message which is the prompt
         if last_message > 0 and res.output[last_message].role != "assistant":
             input = res.output[last_message]
 
+            # Add remaining parallel tool calls
+            if input.role == "tool":
+                first_message = len(res.output)
+                for msg in reversed(res.output):
+                    if msg.action_calls:
+                        break
+                    first_message -= 1
+                input = res.output[first_message:]
+
         return PromptResponse[str](
             input=input,
             message=Message(
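The backward walk collects every trailing `tool` message that follows the last assistant message carrying `action_calls`, so parallel tool results are returned together as the response input. A toy run of the same index arithmetic (the `Msg` tuple and `trailing_tool_messages` helper are illustrative stand-ins, not library types):

from typing import List, NamedTuple, Optional


class Msg(NamedTuple):
    role: str
    action_calls: Optional[list] = None


def trailing_tool_messages(output: List[Msg]) -> List[Msg]:
    # Walk backwards until the message that issued the action calls.
    first_message = len(output)
    for msg in reversed(output):
        if msg.action_calls:
            break
        first_message -= 1
    # Everything after that message is a tool result.
    return output[first_message:]


output = [
    Msg("user"),
    Msg("assistant", action_calls=["call_1", "call_2"]),
    Msg("tool"),
    Msg("tool"),
]
assert trailing_tool_messages(output) == [Msg("tool"), Msg("tool")]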
6 changes: 3 additions & 3 deletions python/packages/ai/teams/ai/models/prompt_response.py
@@ -6,7 +6,7 @@
 from __future__ import annotations
 
 from dataclasses import dataclass
-from typing import Any, Generic, Literal, Optional, TypeVar
+from typing import Any, Generic, List, Literal, Optional, TypeVar, Union
 
 from ..prompts.message import Message
 
@@ -25,9 +25,9 @@ class PromptResponse(Generic[ContentT]):
     Status of the prompt response.
     """
 
-    input: Optional[Message[Any]] = None
+    input: Optional[Union[Message[Any], List[Message[Any]]]] = None
     """
-    User input message sent to the model. `undefined` if no input was sent.
+    Input message sent to the model. `undefined` if no input was sent.
     """
 
     message: Optional[Message[ContentT]] = None
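Code that reads `PromptResponse.input` now has three shapes to handle: `None`, a single `Message`, or a list of them. A small normalization sketch (the `as_message_list` helper is hypothetical, not part of the package):

from typing import Any, List, Optional, Union


def as_message_list(value: Optional[Union[Any, List[Any]]]) -> List[Any]:
    # None -> [], single message -> [message], list -> unchanged.
    if value is None:
        return []
    return value if isinstance(value, list) else [value]


assert as_message_list(None) == []
assert as_message_list("msg") == ["msg"]
assert as_message_list(["m1", "m2"]) == ["m1", "m2"]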
15 changes: 0 additions & 15 deletions python/packages/ai/teams/ai/planners/action_planner.py
@@ -17,7 +17,6 @@
 from ..clients import LLMClient, LLMClientOptions
 from ..models.prompt_completion_model import PromptCompletionModel
 from ..models.prompt_response import PromptResponse
-from ..prompts.message import Message
 from ..prompts.prompt_functions import PromptFunctions
 from ..prompts.prompt_manager import PromptManager
 from ..prompts.prompt_template import PromptTemplate
 
@@ -142,8 +141,6 @@ async def complete_prompt(
             )
         )
 
-        self._add_action_outputs(memory, history_var, client)
-
         return await client.complete_prompt(
             context=context,
             memory=memory,
@@ -152,18 +149,6 @@ async def complete_prompt(
             template=template,
         )
 
-    def _add_action_outputs(self, memory: MemoryBase, history_var: str, client: LLMClient) -> None:
-        history: List[Message] = memory.get(history_var) or []
-
-        if history and len(history) > 1:
-            # Submit tool outputs
-            action_outputs = memory.get("temp.action_outputs") or {}
-            action_calls = history[-1].action_calls or []
-
-            for action_call in action_calls:
-                output = action_outputs[action_call.id]
-                client.add_action_output_to_history(memory, action_call.id, output)
-
     def add_semantic_function(
         self, prompt: Union[str, PromptTemplate], _validator: Optional[PromptResponseValidator]
     ) -> "ActionPlanner":
2 changes: 2 additions & 0 deletions python/packages/ai/teams/ai/prompts/__init__.py
@@ -3,6 +3,7 @@
 Licensed under the MIT License.
 """
 
+from .action_output_message import ActionOutputMessage
 from .assistant_message import AssistantMessage
 from .augmentation_config import AugmentationConfig
 from .completion_config import CompletionConfig
 
@@ -57,4 +58,5 @@
     "SystemMessage",
     "UserInputMessage",
     "UserMessage",
+    "ActionOutputMessage",
 ]
79 changes: 79 additions & 0 deletions python/packages/ai/teams/ai/prompts/action_output_message.py
@@ -0,0 +1,79 @@
+"""
+Copyright (c) Microsoft Corporation. All rights reserved.
+Licensed under the MIT License.
+"""
+
+from __future__ import annotations
+
+from typing import List
+
+from botbuilder.core import TurnContext
+
+from ...state import MemoryBase
+from ..tokenizers import Tokenizer
+from .message import Message
+from .prompt_functions import PromptFunctions
+from .rendered_prompt_section import RenderedPromptSection
+from .sections.prompt_section_base import PromptSectionBase
+
+
+class ActionOutputMessage(PromptSectionBase):
+    """
+    A section capable of rendering action outputs.
+    """
+
+    _output_variable: str
+    _history_variable: str
+
+    def __init__(
+        self,
+        history_variable: str,
+        output_variable: str = "temp.action_outputs",
+    ):
+        """
+        Creates a new 'ActionOutputMessage' instance.
+
+        Args:
+            history_variable (str): Name of the conversation history.
+            output_variable (str, optional): Name of the action outputs.
+                Defaults to `temp.action_outputs`.
+        """
+        super().__init__(-1, True, "\n", "action: ")
+        self._output_variable = output_variable
+        self._history_variable = history_variable
+
+    async def render_as_messages(
+        self,
+        context: TurnContext,
+        memory: MemoryBase,
+        functions: PromptFunctions,
+        tokenizer: Tokenizer,
+        max_tokens: int,
+    ) -> RenderedPromptSection[List[Message[str]]]:
+        """
+        Renders the actions section as a list of messages.
+
+        Args:
+            context (TurnContext): Context for the current turn of conversation with the user.
+            memory (MemoryBase): An interface for accessing state values.
+            functions (PromptFunctions): Registry of functions that can be used by the section.
+            tokenizer (Tokenizer): Tokenizer to use when rendering the section.
+            max_tokens (int): Maximum number of tokens allowed to be rendered.
+
+        Returns:
+            RenderedPromptSection[List[Message]]: The rendered prompt section as a list of messages.
+        """
+
+        history: List[Message] = memory.get(self._history_variable) or []
+        messages: List[Message] = []
+
+        if len(history) > 1:
+            action_outputs = memory.get(self._output_variable) or {}
+            action_calls = history[-1].action_calls or []
+
+            for action_call in action_calls:
+                output = action_outputs[action_call.id]
+                message = Message[str](role="tool", action_call_id=action_call.id, content=output)
+                messages.append(message)
+
+        return RenderedPromptSection(output=messages, length=len(messages), too_long=False)
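At render time the section pairs each `action_call` on the last history message with its stored output and emits one `tool` message per call. A condensed sketch of that pairing logic, with dicts standing in for `Message` objects and for memory state (names hypothetical):

from typing import Any, Dict, List


def render_action_outputs(
    history: List[Dict[str, Any]], action_outputs: Dict[str, str]
) -> List[Dict[str, Any]]:
    # Mirrors ActionOutputMessage.render_as_messages: one tool message per call.
    messages: List[Dict[str, Any]] = []
    if len(history) > 1:
        for action_call in history[-1].get("action_calls") or []:
            messages.append({
                "role": "tool",
                "action_call_id": action_call["id"],
                "content": action_outputs[action_call["id"]],
            })
    return messages


history = [
    {"role": "user", "content": "Turn the lights on"},
    {"role": "assistant", "action_calls": [{"id": "call_1"}, {"id": "call_2"}]},
]
outputs = {"call_1": "lights: on", "call_2": "thermostat: 21C"}
assert [m["action_call_id"] for m in render_action_outputs(history, outputs)] == ["call_1", "call_2"]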
11 changes: 11 additions & 0 deletions python/packages/ai/teams/ai/prompts/prompt_manager.py
@@ -14,6 +14,7 @@
 from botbuilder.core import TurnContext
 
 import teams.ai.augmentations
+from teams.ai.prompts.action_output_message import ActionOutputMessage
 
 from ...app_error import ApplicationError
 from ...state import MemoryBase
 
@@ -322,6 +323,16 @@ async def get_prompt(self, name: str) -> PromptTemplate:
         elif template_config.completion.include_input:
             sections.append(UserMessage("{{$temp.input}}", self._options.max_input_tokens))
 
+        if (
+            template_config.augmentation
+            and template_config.augmentation.augmentation_type == "tools"
+        ):
+            include_history = template_config.completion.include_history
+            history_var = (
+                f"conversation.{name}_history" if include_history else f"temp.{name}_history"
+            )
+            sections.append(ActionOutputMessage(history_variable=history_var))
+
         template = PromptTemplate(
             template_name, Prompt(sections), template_config, template_actions
         )
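The effect is that any prompt whose config declares the `tools` augmentation automatically gains the tool-output section, bound to the same history variable the conversation-history section uses. A simplified sketch of the section assembly (config flattened to a dict and sections reduced to strings, purely for illustration):

from typing import Any, Dict, List


def build_sections(name: str, config: Dict[str, Any]) -> List[str]:
    sections: List[str] = ["SystemMessage", "UserMessage"]
    augmentation = config.get("augmentation") or {}
    if augmentation.get("augmentation_type") == "tools":
        # Tool outputs render against the same variable history is kept in.
        include_history = config["completion"]["include_history"]
        history_var = (
            f"conversation.{name}_history" if include_history else f"temp.{name}_history"
        )
        sections.append(f"ActionOutputMessage({history_var})")
    return sections


sections = build_sections(
    "chat",
    {"augmentation": {"augmentation_type": "tools"}, "completion": {"include_history": True}},
)
assert sections[-1] == "ActionOutputMessage(conversation.chat_history)"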
18 changes: 0 additions & 18 deletions python/packages/ai/tests/ai/clients/test_llm_client.py
@@ -81,24 +81,6 @@ def create_mock_context(
         context.activity.from_property.id = user_id
         return context
 
-    @mock.patch("openai.AsyncOpenAI", return_value=MockAsyncOpenAI)
-    async def test_add_action_output(self, mock_async_openai):
-        context = self.create_mock_context()
-        state = await TurnState[ConversationState, UserState, TempState].load(context)
-
-        model = OpenAIModel(OpenAIModelOptions(api_key="", default_model="model"))
-        client = LLMClient(LLMClientOptions(model))
-        state.set(client.options.history_variable, [Message(role="assistant", content="results")])
-        client.add_action_output_to_history(memory=state, id="123", results="results")
-
-        expected_history = [
-            Message(role="assistant", content="results"),
-            Message(role="tool", action_call_id="123", content="results"),
-        ]
-
-        self.assertTrue(mock_async_openai.called)
-        self.assertEqual(state.get(client.options.history_variable), expected_history)
-
     @mock.patch("openai.AsyncOpenAI", return_value=MockAsyncOpenAI)
     async def test_complete_prompt_no_attempts(self, mock_async_openai):
         context = self.create_mock_context()
67 changes: 66 additions & 1 deletion python/packages/ai/tests/ai/models/test_openai_model.py
@@ -23,8 +23,13 @@
     PromptTemplateConfig,
     TextSection,
 )
+from teams.ai.prompts.action_output_message import ActionOutputMessage
+from teams.ai.prompts.augmentation_config import AugmentationConfig
-from teams.ai.prompts.message import ActionCall, ActionFunction
+from teams.ai.prompts.message import ActionCall, ActionFunction, Message
+from teams.ai.prompts.prompt import Prompt
+from teams.ai.prompts.sections.conversation_history_section import (
+    ConversationHistorySection,
+)
 from teams.ai.tokenizers import GPTTokenizer
 from teams.state import TurnState
 
@@ -277,6 +282,66 @@ async def test_should_be_success(self, mock_async_openai):
         self.assertTrue(mock_async_openai.called)
         self.assertEqual(res.status, "success")
 
+    @mock.patch("openai.AsyncOpenAI", return_value=MockAsyncOpenAI)
+    async def test_should_succeed_on_prev_tool_calls(self, mock_async_openai):
+        context = self.create_mock_context()
+        state = TurnState()
+        state.temp = {}
+        state.conversation = {}
+        state.set(
+            "conversation.default_history",
+            [
+                Message(role="user", content="Turn the lights on"),
+                Message(
+                    role="assistant",
+                    action_calls=[
+                        ActionCall(
+                            id="test_tool_1",
+                            type="function",
+                            function=ActionFunction(name="tool_one", arguments="{}"),
+                        ),
+                        ActionCall(
+                            id="test_tool_2",
+                            type="function",
+                            function=ActionFunction(name="tool_two", arguments="{}"),
+                        ),
+                    ],
+                ),
+            ],
+        )
+        state.set("temp.action_outputs", {"test_tool_1": "hello", "test_tool_2": "world"})
+        await state.load(context)
+
+        model = OpenAIModel(OpenAIModelOptions(api_key="", default_model="model"))
+        res = await model.complete_prompt(
+            context=context,
+            memory=state,
+            functions=cast(PromptFunctions, {}),
+            tokenizer=GPTTokenizer(),
+            template=PromptTemplate(
+                name="default",
+                prompt=Prompt(
+                    sections=[
+                        ConversationHistorySection("conversation.default_history"),
+                        ActionOutputMessage("conversation.default_history"),
+                    ]
+                ),
+                config=PromptTemplateConfig(
+                    schema=1.0,
+                    type="completion",
+                    description="test",
+                    completion=CompletionConfig(completion_type="chat"),
+                ),
+            ),
+        )
+
+        self.assertTrue(mock_async_openai.called)
+        self.assertEqual(res.status, "success")
+        if res.input and isinstance(res.input, list):
+            self.assertEqual(len(res.input), 2)
+            self.assertEqual(res.input[0].action_call_id, "test_tool_1")
+            self.assertEqual(res.input[1].action_call_id, "test_tool_2")
+
     @mock.patch("openai.AsyncOpenAI", return_value=MockAsyncOpenAI)
     async def test_wrong_augmentation_type(self, mock_async_openai):
         context = self.create_mock_context()