Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PY] fix: Map tokens config to max_tokens when non-o1 model is used. #2151

Merged
merged 6 commits into from
Oct 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 72 additions & 66 deletions python/packages/ai/teams/ai/models/openai_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,77 +225,15 @@ async def complete_prompt(
if self._options.logger is not None:
self._options.logger.debug(f"PROMPT:\n{res.output}")

messages: List[chat.ChatCompletionMessageParam] = []

for msg in res.output:
param: Union[
chat.ChatCompletionUserMessageParam,
chat.ChatCompletionAssistantMessageParam,
chat.ChatCompletionSystemMessageParam,
chat.ChatCompletionToolMessageParam,
] = chat.ChatCompletionUserMessageParam(
role="user",
content=msg.content if msg.content is not None else "",
)

if msg.name:
setattr(param, "name", msg.name)

if msg.role == "assistant":
param = chat.ChatCompletionAssistantMessageParam(
role="assistant",
content=msg.content if msg.content is not None else "",
)

tool_call_params: List[chat.ChatCompletionMessageToolCallParam] = []

if msg.action_calls and len(msg.action_calls) > 0:
for tool_call in msg.action_calls:
tool_call_params.append(
chat.ChatCompletionMessageToolCallParam(
id=tool_call.id,
function=Function(
name=tool_call.function.name,
arguments=tool_call.function.arguments,
),
type=tool_call.type,
)
)
param["content"] = None
param["tool_calls"] = tool_call_params

if msg.name:
param["name"] = msg.name

elif msg.role == "tool":
param = chat.ChatCompletionToolMessageParam(
role="tool",
tool_call_id=msg.action_call_id if msg.action_call_id else "",
content=msg.content if msg.content else "",
)
elif msg.role == "system":
# o1 models do not support system messages
if is_o1_model:
param = chat.ChatCompletionUserMessageParam(
role="user",
content=msg.content if msg.content is not None else "",
)
else:
param = chat.ChatCompletionSystemMessageParam(
role="system",
content=msg.content if msg.content is not None else "",
)

if msg.name:
param["name"] = msg.name

messages.append(param)
messages: List[chat.ChatCompletionMessageParam]
messages = self._map_messages(res.output, is_o1_model)

try:
extra_body = {}
if template.config.completion.data_sources is not None:
extra_body["data_sources"] = template.config.completion.data_sources

max_tokens = template.config.completion.max_tokens
completion = await self._client.chat.completions.create(
messages=messages,
model=model,
Expand All @@ -305,7 +243,8 @@ async def complete_prompt(
frequency_penalty=template.config.completion.frequency_penalty,
top_p=template.config.completion.top_p if not is_o1_model else 1,
temperature=template.config.completion.temperature if not is_o1_model else 1,
max_completion_tokens=template.config.completion.max_tokens,
max_tokens=max_tokens if not is_o1_model else NOT_GIVEN,
max_completion_tokens=max_tokens if is_o1_model else NOT_GIVEN,
tools=tools if len(tools) > 0 else NOT_GIVEN,
tool_choice=tool_choice if len(tools) > 0 else NOT_GIVEN,
parallel_tool_calls=parallel_tool_calls if len(tools) > 0 else NOT_GIVEN,
Expand Down Expand Up @@ -436,3 +375,70 @@ async def complete_prompt(
status of {err.code}: {err.message}
""",
)

def _map_messages(
    self, msgs: List[Message], is_o1_model: bool
) -> List[chat.ChatCompletionMessageParam]:
    """
    Convert internal `Message` objects into OpenAI chat-completion message params.

    Args:
        msgs (List[Message]): the prompt messages to convert.
        is_o1_model (bool): whether the target model is an o1 model.
            o1 models do not support system messages, so system messages
            are downgraded to user messages for them.

    Returns:
        List[chat.ChatCompletionMessageParam]: the mapped messages, in order.
    """
    output: List[chat.ChatCompletionMessageParam] = []
    for msg in msgs:
        # Default mapping: treat the message as a user message.
        param: Union[
            chat.ChatCompletionUserMessageParam,
            chat.ChatCompletionAssistantMessageParam,
            chat.ChatCompletionSystemMessageParam,
            chat.ChatCompletionToolMessageParam,
        ] = chat.ChatCompletionUserMessageParam(
            role="user",
            content=msg.content if msg.content is not None else "",
        )

        if msg.name:
            # The *MessageParam types are TypedDicts (plain dicts at runtime),
            # so the name must be set via item assignment; setattr() on a dict
            # raises AttributeError. This also matches the assignment style
            # used in the assistant/system branches below.
            param["name"] = msg.name

        if msg.role == "assistant":
            param = chat.ChatCompletionAssistantMessageParam(
                role="assistant",
                content=msg.content if msg.content is not None else "",
            )

            tool_call_params: List[chat.ChatCompletionMessageToolCallParam] = []

            if msg.action_calls and len(msg.action_calls) > 0:
                for tool_call in msg.action_calls:
                    tool_call_params.append(
                        chat.ChatCompletionMessageToolCallParam(
                            id=tool_call.id,
                            function=Function(
                                name=tool_call.function.name,
                                arguments=tool_call.function.arguments,
                            ),
                            type=tool_call.type,
                        )
                    )
                # When tool calls are present the assistant message carries
                # them instead of text content.
                param["content"] = None
                param["tool_calls"] = tool_call_params

            if msg.name:
                param["name"] = msg.name

        elif msg.role == "tool":
            param = chat.ChatCompletionToolMessageParam(
                role="tool",
                tool_call_id=msg.action_call_id if msg.action_call_id else "",
                content=msg.content if msg.content else "",
            )
        elif msg.role == "system":
            # o1 models do not support system messages
            if is_o1_model:
                param = chat.ChatCompletionUserMessageParam(
                    role="user",
                    content=msg.content if msg.content is not None else "",
                )
            else:
                param = chat.ChatCompletionSystemMessageParam(
                    role="system",
                    content=msg.content if msg.content is not None else "",
                )

                if msg.name:
                    param["name"] = msg.name

        output.append(param)
    return output
88 changes: 78 additions & 10 deletions python/packages/ai/tests/ai/models/test_openai_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,19 +100,18 @@ class MockAsyncCompletions:
should_error = False
has_tool_call = False
has_tool_calls = False
is_o1_model = False
messages = []
create_params = None

def __init__(
self, should_error=False, has_tool_call=False, has_tool_calls=False, is_o1_model=False
self, should_error=False, has_tool_call=False, has_tool_calls=False
) -> None:
self.should_error = should_error
self.has_tool_call = has_tool_call
self.has_tool_calls = has_tool_calls
self.is_o1_model = is_o1_model
self.messages = []

async def create(self, **kwargs) -> chat.ChatCompletion:
self.create_params = kwargs

if self.should_error:
raise openai.BadRequestError(
"bad request",
Expand All @@ -126,9 +125,6 @@ async def create(self, **kwargs) -> chat.ChatCompletion:
if self.has_tool_calls:
return await self.handle_tool_calls(**kwargs)

if self.is_o1_model:
self.messages = kwargs["messages"]

return chat.ChatCompletion(
id="",
choices=[
Expand Down Expand Up @@ -294,7 +290,6 @@ async def test_should_be_success(self, mock_async_openai):

@mock.patch("openai.AsyncOpenAI", return_value=MockAsyncOpenAI)
async def test_o1_model_should_use_user_message_over_system_message(self, mock_async_openai):
mock_async_openai.return_value.chat.completions.is_o1_model = True
context = self.create_mock_context()
state = TurnState()
state.temp = {}
Expand All @@ -319,8 +314,81 @@ async def test_o1_model_should_use_user_message_over_system_message(self, mock_a

self.assertTrue(mock_async_openai.called)
self.assertEqual(res.status, "success")
create_params = mock_async_openai.return_value.chat.completions.create_params
self.assertEqual(
create_params["messages"][0]["role"], "user"
)

@mock.patch("openai.AsyncOpenAI", return_value=MockAsyncOpenAI)
async def test_o1_model_should_use_max_completion_tokens_param(self, mock_async_openai):
    # o1 models must be sent `max_completion_tokens`; `max_tokens` is omitted.
    turn_context = self.create_mock_context()
    turn_state = TurnState()
    turn_state.temp = {}
    turn_state.conversation = {}

    completion_config = CompletionConfig(completion_type="chat")
    completion_config.max_tokens = 1000

    o1_model = OpenAIModel(OpenAIModelOptions(api_key="", default_model="o1-"))
    template = PromptTemplate(
        name="default",
        prompt=Prompt(sections=[TemplateSection("prompt text", "system")]),
        config=PromptTemplateConfig(
            schema=1.0,
            type="completion",
            description="test",
            completion=completion_config,
        ),
    )

    res = await o1_model.complete_prompt(
        context=turn_context,
        memory=turn_state,
        functions=cast(PromptFunctions, {}),
        tokenizer=GPTTokenizer(),
        template=template,
    )

    self.assertTrue(mock_async_openai.called)
    self.assertEqual(res.status, "success")

    # Inspect the kwargs captured by the mock's create() call.
    params = mock_async_openai.return_value.chat.completions.create_params
    self.assertEqual(params["max_completion_tokens"], 1000)
    self.assertEqual(params["max_tokens"], openai.NOT_GIVEN)

@mock.patch("openai.AsyncOpenAI", return_value=MockAsyncOpenAI)
async def test_non_o1_model_should_use_max_tokens_param(self, mock_async_openai):
    # Non-o1 models must be sent `max_tokens`; `max_completion_tokens` is omitted.
    context = self.create_mock_context()
    state = TurnState()
    state.temp = {}
    state.conversation = {}
    model = OpenAIModel(OpenAIModelOptions(api_key="", default_model="non-o1"))
    completion = CompletionConfig(completion_type="chat")
    completion.max_tokens = 1000
    res = await model.complete_prompt(
        context=context,
        memory=state,
        functions=cast(PromptFunctions, {}),
        tokenizer=GPTTokenizer(),
        template=PromptTemplate(
            name="default",
            prompt=Prompt(sections=[TemplateSection("prompt text", "system")]),
            config=PromptTemplateConfig(
                schema=1.0,
                type="completion",
                description="test",
                completion=completion,
            ),
        ),
    )

    self.assertTrue(mock_async_openai.called)
    self.assertEqual(res.status, "success")
    create_params = mock_async_openai.return_value.chat.completions.create_params
    self.assertEqual(
        create_params["max_tokens"], 1000
    )
    # NOTE(review): the scraped diff fused a deleted line into this assertion
    # ("...messages[0]['role'], 'user'" with no trailing comma), which is a
    # syntax error; reconstructed to mirror the o1 test above.
    self.assertEqual(
        create_params["max_completion_tokens"], openai.NOT_GIVEN
    )

@mock.patch("openai.AsyncOpenAI", return_value=MockAsyncOpenAI)
Expand Down
6 changes: 3 additions & 3 deletions python/samples/04.ai.a.twentyQuestions/src/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,14 @@

if config.OPENAI_KEY:
model = OpenAIModel(
OpenAIModelOptions(api_key=config.OPENAI_KEY, default_model="gpt-3.5-turbo")
OpenAIModelOptions(api_key=config.OPENAI_KEY, default_model="gpt-4o")
)
elif config.AZURE_OPENAI_KEY and config.AZURE_OPENAI_ENDPOINT:
model = OpenAIModel(
AzureOpenAIModelOptions(
api_key=config.AZURE_OPENAI_KEY,
default_model="gpt-35-turbo",
api_version="2023-03-15-preview",
default_model="gpt-4o",
api_version="2024-08-01-preview",
endpoint=config.AZURE_OPENAI_ENDPOINT,
)
)
Expand Down
1 change: 1 addition & 0 deletions python/samples/04.ai.a.twentyQuestions/teamsapp.yml
Original file line number Diff line number Diff line change
Expand Up @@ -97,3 +97,4 @@ deploy:
# You can replace it with your existing Azure Resource id
# or add it to your environment variable file.
resourceId: ${{BOT_AZURE_APP_SERVICE_RESOURCE_ID}}
projectId: 38b5ad68-9f64-41b8-a503-4d3200655664
Loading