Start of long-document-creation extension (microsoft#314)
- document model with sections
- tools for LLM to manage document state
- system prompt with inclusion of current document state
- currently exploring o3-mini model
markwaddle authored Feb 5, 2025
1 parent ab6a31b commit dce0c55
Showing 11 changed files with 998 additions and 59 deletions.
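The commit message above sketches the moving pieces. As a rough illustration of the first bullet only (this is not code from the commit; Section, Document, and replace_section are hypothetical names), a sectioned document model that LLM tool calls could mutate might look like:

# Illustrative sketch -- not part of this changeset. Uses pydantic models,
# as the rest of the commit does; all names here are hypothetical.
from pydantic import BaseModel, Field


class Section(BaseModel):
    heading: str
    content: str = ""


class Document(BaseModel):
    title: str = ""
    sections: list[Section] = Field(default_factory=list)

    def replace_section(self, heading: str, content: str) -> None:
        # A mutation an LLM tool call could use to manage document state.
        for section in self.sections:
            if section.heading == heading:
                section.content = content
                return
        self.sections.append(Section(heading=heading, content=content))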
@@ -2,6 +2,7 @@
 import logging
 from enum import StrEnum
 from pathlib import Path
+from typing import Generic, TypeVar

 from guided_conversation.guided_conversation_agent import GuidedConversation as GuidedConversationAgent
 from openai import AsyncOpenAI
@@ -38,7 +39,10 @@ class GC_UserDecision(StrEnum):
     EXIT_EARLY = "exit_early"


-class GuidedConversation:
+TArtifactModel = TypeVar("TArtifactModel", bound=BaseModel)
+
+
+class GuidedConversation(Generic[TArtifactModel]):
     """
     An agent for managing artifacts.
     """
@@ -48,7 +52,7 @@ def __init__(
         config: AssistantConfigModel,
         openai_client: AsyncOpenAI,
         agent_config: GuidedConversationConfigModel,
-        artifact_model: type[BaseModel],
+        artifact_model: type[TArtifactModel],
         conversation_context: ConversationContext,
         artifact_updates: dict = {},
     ) -> None:
@@ -113,57 +117,44 @@ async def step_conversation(
         # Save the state of the guided conversation agent
         _write_guided_conversation_state(self.conversation_context, self.guided_conversation_agent.to_json())

-        # convert information in artifact for Document Agent
-        # conversation_status:  # this should relate to result.is_conversation_over
-        # final_response:  # replace result.ai_message with final_response if "user_completed"
-
-        final_response: str = ""
-        conversation_status_str: str | None = None
-        user_decision_str: str | None = None
-        response: str = ""
-
         # to_json is actually to dict
         gc_dict = self.guided_conversation_agent.to_json()
-        artifact_item = gc_dict.get("artifact")
-        if artifact_item is not None:
-            artifact_item = artifact_item.get("artifact")
-            if artifact_item is not None:
-                final_response = artifact_item.get("final_response")
-                conversation_status_str = artifact_item.get("conversation_status")
-                user_decision_str = artifact_item.get("user_decision")
+        artifact_item = gc_dict["artifact"]["artifact"]
+        conversation_status_str: str | None = artifact_item.get("conversation_status")
+        user_decision_str: str | None = artifact_item.get("user_decision")

+        response: str = ""
         gc_conversation_status = GC_ConversationStatus.UNDEFINED
         gc_user_decision: GC_UserDecision = GC_UserDecision.UNDEFINED
-        if conversation_status_str is not None:
-            match conversation_status_str:
-                case GC_ConversationStatus.USER_COMPLETED:
-                    response = final_response or result.ai_message or ""
-                    gc_conversation_status = GC_ConversationStatus.USER_COMPLETED
-                    match user_decision_str:
-                        case GC_UserDecision.UPDATE_OUTLINE:
-                            gc_user_decision = GC_UserDecision.UPDATE_OUTLINE
-                        case GC_UserDecision.DRAFT_PAPER:
-                            gc_user_decision = GC_UserDecision.DRAFT_PAPER
-                        case GC_UserDecision.UPDATE_CONTENT:
-                            gc_user_decision = GC_UserDecision.UPDATE_CONTENT
-                        case GC_UserDecision.DRAFT_NEXT_CONTENT:
-                            gc_user_decision = GC_UserDecision.DRAFT_NEXT_CONTENT
-                        case GC_UserDecision.EXIT_EARLY:
-                            gc_user_decision = GC_UserDecision.EXIT_EARLY
-
-                    _delete_guided_conversation_state(self.conversation_context)
-                case GC_ConversationStatus.USER_INITIATED:
-                    if result.ai_message is not None:
-                        response = result.ai_message
-                    else:
-                        response = ""
-                    gc_conversation_status = GC_ConversationStatus.USER_INITIATED
-                case GC_ConversationStatus.USER_RETURNED:
-                    if result.ai_message is not None:
-                        response = result.ai_message
-                    else:
-                        response = ""
-                    gc_conversation_status = GC_ConversationStatus.USER_RETURNED
-                    gc_user_decision = GC_UserDecision.UNDEFINED
+
+        match conversation_status_str:
+            case GC_ConversationStatus.USER_COMPLETED:
+                gc_conversation_status = GC_ConversationStatus.USER_COMPLETED
+                final_response: str | None = artifact_item.get("final_response")
+                final_response = final_response if final_response != "Unanswered" else ""
+                response = final_response or result.ai_message or ""
+
+                match user_decision_str:
+                    case GC_UserDecision.UPDATE_OUTLINE:
+                        gc_user_decision = GC_UserDecision.UPDATE_OUTLINE
+                    case GC_UserDecision.DRAFT_PAPER:
+                        gc_user_decision = GC_UserDecision.DRAFT_PAPER
+                    case GC_UserDecision.UPDATE_CONTENT:
+                        gc_user_decision = GC_UserDecision.UPDATE_CONTENT
+                    case GC_UserDecision.DRAFT_NEXT_CONTENT:
+                        gc_user_decision = GC_UserDecision.DRAFT_NEXT_CONTENT
+                    case GC_UserDecision.EXIT_EARLY:
+                        gc_user_decision = GC_UserDecision.EXIT_EARLY
+
+                _delete_guided_conversation_state(self.conversation_context)
+
+            case GC_ConversationStatus.USER_INITIATED:
+                gc_conversation_status = GC_ConversationStatus.USER_INITIATED
+                response = result.ai_message or ""
+
+            case GC_ConversationStatus.USER_RETURNED:
+                gc_conversation_status = GC_ConversationStatus.USER_RETURNED
+                response = result.ai_message or ""

         return response, gc_conversation_status, gc_user_decision
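With GuidedConversation now generic over TArtifactModel, a caller can bind a concrete artifact model and branch on the typed status tuple that step_conversation returns. A minimal sketch, not part of this diff: OutlineArtifact and run_outline_step are hypothetical, the parameters are assumed to exist in the calling assistant, and step_conversation's argument list (not shown in this hunk) is elided.

# Hedged sketch of a call site for the now-generic GuidedConversation.
class OutlineArtifact(BaseModel):
    conversation_status: str = GC_ConversationStatus.UNDEFINED
    user_decision: str = GC_UserDecision.UNDEFINED
    final_response: str = "Unanswered"


async def run_outline_step(config, openai_client, agent_config, conversation_context):
    gc = GuidedConversation[OutlineArtifact](
        config=config,
        openai_client=openai_client,
        agent_config=agent_config,
        artifact_model=OutlineArtifact,
        conversation_context=conversation_context,
    )
    # step_conversation's full signature is not shown in this diff; arguments elided.
    response, status, decision = await gc.step_conversation()
    if status is GC_ConversationStatus.USER_COMPLETED and decision is GC_UserDecision.EXIT_EARLY:
        return None  # user bailed out early; nothing to draft
    return response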

@@ -158,6 +158,11 @@ async def _mode_execute(
                     )
                     break  # ok - get more user input

+                case StepStatus.USER_EXIT_EARLY:
+                    state.mode_status = ModeStatus.USER_EXIT_EARLY
+                    logger.info("Document Agent: User exited early. Completed.")
+                    break  # ok - done early :)
+
                 case StepStatus.USER_COMPLETED:
                     state.mode_status = ModeStatus.USER_COMPLETED
@@ -202,11 +207,6 @@ def get_next_step(current_step_name: StepName, user_decision: GC_UserDecision) -
                     )
                     continue  # ok - don't need user input yet

-                case StepStatus.USER_EXIT_EARLY:
-                    state.mode_status = ModeStatus.USER_EXIT_EARLY
-                    logger.info("Document Agent: User exited early. Completed.")
-                    break  # ok - done early :)
-
         return state.mode_status

     # endregion
@@ -0,0 +1,5 @@
+from .extension import ArtifactCreationExtension
+
+__all__ = [
+    "ArtifactCreationExtension",
+]
@@ -0,0 +1,196 @@
import logging
from abc import ABC, abstractmethod
from typing import Any, AsyncIterator, Awaitable, Callable, Coroutine, Generic, Iterable, TypeVar

from attr import dataclass
from openai import pydantic_function_tool
from openai.types.chat import (
    ChatCompletionAssistantMessageParam,
    ChatCompletionMessageParam,
    ChatCompletionMessageToolCallParam,
    ChatCompletionToolMessageParam,
    ParsedFunctionToolCall,
)
from pydantic import BaseModel
from semantic_workbench_assistant.assistant_app.context import ConversationContext

from .config import LLMConfig

logger = logging.getLogger(__name__)


class NoResponseChoicesError(Exception):
    pass


class NoParsedMessageError(Exception):
    pass


ResponseModelT = TypeVar("ResponseModelT", bound=BaseModel)


async def structured_completion(
    llm_config: LLMConfig, messages: list[ChatCompletionMessageParam], response_model: type[ResponseModelT]
) -> tuple[ResponseModelT, dict[str, Any]]:
    async with llm_config.openai_client_factory() as client:
        response = await client.beta.chat.completions.parse(
            messages=messages,
            model=llm_config.openai_model,
            response_format=response_model,
            max_tokens=llm_config.max_response_tokens,
        )

        if not response.choices:
            raise NoResponseChoicesError()

        if not response.choices[0].message.parsed:
            raise NoParsedMessageError()

        metadata = {
            "request": {
                "model": llm_config.openai_model,
                "messages": messages,
                "max_tokens": llm_config.max_response_tokens,
                "response_format": response_model.model_json_schema(),
            },
            "response": response.model_dump(),
        }

        return response.choices[0].message.parsed, metadata


class ToolArgsModel(ABC, BaseModel):
    @abstractmethod
    def set_context(self, context: ConversationContext) -> None: ...


TToolArgs = TypeVar("TToolArgs", bound=ToolArgsModel)
TToolResult = TypeVar("TToolResult", bound=BaseModel)


@dataclass
class CompletionTool(Generic[TToolArgs, TToolResult]):
    function: Callable[[TToolArgs], Coroutine[Any, Any, TToolResult]]
    argument_model: type[TToolArgs]
    description: str = ""
    """Description of the tool. If omitted, will use the docstring of the function."""


@dataclass
class LLMResponse:
    metadata: dict[str, Any]


@dataclass
class ToolCallResponse(LLMResponse):
    tool_call: ParsedFunctionToolCall
    result: BaseModel


@dataclass
class MessageResponse(LLMResponse):
    message: str


async def completion_with_tools(
    llm_config: LLMConfig,
    context: ConversationContext,
    get_messages: Callable[[], Awaitable[Iterable[ChatCompletionMessageParam]]],
    tools: list[CompletionTool] = [],
) -> AsyncIterator[ToolCallResponse | MessageResponse]:
    openai_tools = [
        pydantic_function_tool(
            tool.argument_model,
            name=tool.function.__name__,
            description=tool.description or (tool.function.__doc__ or "").strip(),
        )
        for tool in tools
    ]

    tool_messages: list[ChatCompletionMessageParam] = []
    reasoning_effort = "medium"

    async with llm_config.openai_client_factory() as client:
        while True:
            completion_messages = list(await get_messages()) + tool_messages

            response = await client.beta.chat.completions.parse(
                messages=completion_messages,
                model=llm_config.openai_model,
                max_completion_tokens=llm_config.max_response_tokens + 25_000,
                tools=openai_tools,
                reasoning_effort=reasoning_effort,
                # parallel_tool_calls=False,
            )

            message = response.choices[0].message

            if not message.tool_calls:
                metadata = {
                    "request": {
                        "model": llm_config.openai_model,
                        "messages": completion_messages,
                        "tools": openai_tools,
                        "reasoning_effort": reasoning_effort,
                        "max_completion_tokens": llm_config.max_response_tokens,
                    },
                    "response": response.model_dump(),
                }
                yield MessageResponse(message=str(message.content), metadata=metadata)
                return

            async with context.set_status("calling tools..."):
                logger.info("tool calls: %s", message.tool_calls)

                # append the assistant message with the tool calls for the next iteration
                tool_messages.append(
                    ChatCompletionAssistantMessageParam(
                        role="assistant",
                        tool_calls=[
                            ChatCompletionMessageToolCallParam(
                                id=c.id,
                                function={
                                    "name": c.function.name,
                                    "arguments": c.function.arguments,
                                },
                                type="function",
                            )
                            for c in message.tool_calls
                        ],
                    )
                )

                for tool_call in message.tool_calls:
                    function = tool_call.function

                    # find the matching tool
                    tool = next((t for t in tools if t.function.__name__ == function.name), None)
                    if tool is None:
                        raise ValueError(f"Unknown tool call: {tool_call.function}")

                    # validate the args and call the tool function
                    args = tool.argument_model.model_validate(function.parsed_arguments)
                    args.set_context(context)
                    result: BaseModel = await tool.function(args)

                    metadata = {
                        "request": {
                            "model": llm_config.openai_model,
                            "messages": completion_messages,
                            "tools": openai_tools,
                            "max_tokens": llm_config.max_response_tokens,
                        },
                        "response": response.model_dump(),
                        "tool_call": tool_call.model_dump(mode="json"),
                        "tool_result": result.model_dump(mode="json"),
                    }
                    yield ToolCallResponse(tool_call=tool_call, result=result, metadata=metadata)

                    # append the tool result to the messages for the next iteration
                    tool_messages.append(
                        ChatCompletionToolMessageParam(
                            content=result.model_dump_json(),
                            role="tool",
                            tool_call_id=tool_call.id,
                        )
                    )
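The tool loop above can be exercised end to end with a toy tool. A minimal usage sketch, assuming the LLMConfig dataclass from the next file in this changeset, an OPENAI_API_KEY in the environment, and an invented echo tool (EchoArgs, EchoResult, echo, and demo are all hypothetical names, not part of the commit):

# Hedged usage sketch for completion_with_tools; the echo tool is invented.
from openai import AsyncOpenAI


class EchoArgs(ToolArgsModel):
    text: str

    def set_context(self, context: ConversationContext) -> None:
        pass  # this toy tool has no use for the conversation context


class EchoResult(BaseModel):
    text: str


async def echo(args: EchoArgs) -> EchoResult:
    """Echo the provided text back to the model."""
    return EchoResult(text=args.text)


async def demo(context: ConversationContext) -> None:
    llm_config = LLMConfig(
        openai_client_factory=lambda: AsyncOpenAI(),  # assumes OPENAI_API_KEY is set
        openai_model="o3-mini",  # the commit message mentions exploring o3-mini
        max_response_tokens=1_000,
    )

    async def get_messages() -> list[ChatCompletionMessageParam]:
        return [{"role": "user", "content": "Use the echo tool to say hello."}]

    tools = [CompletionTool(function=echo, argument_model=EchoArgs)]
    async for item in completion_with_tools(llm_config, context, get_messages, tools=tools):
        if isinstance(item, ToolCallResponse):
            print("tool result:", item.result)
        else:
            print("assistant:", item.message)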
@@ -0,0 +1,11 @@
from dataclasses import dataclass
from typing import Callable

from openai import AsyncOpenAI


@dataclass
class LLMConfig:
    openai_client_factory: Callable[[], AsyncOpenAI]
    openai_model: str
    max_response_tokens: int
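A quick sketch of wiring this config into the structured_completion helper above. The Outline model, model name, and prompt are illustrative only; the client factory assumes OPENAI_API_KEY is set in the environment.

# Hedged sketch: a structured completion driven by LLMConfig.
import asyncio

from openai import AsyncOpenAI
from pydantic import BaseModel


class Outline(BaseModel):
    title: str
    section_headings: list[str]


async def main() -> None:
    llm_config = LLMConfig(
        openai_client_factory=lambda: AsyncOpenAI(),
        openai_model="gpt-4o",  # illustrative; any model supporting structured outputs
        max_response_tokens=500,
    )
    outline, metadata = await structured_completion(
        llm_config,
        messages=[{"role": "user", "content": "Outline a short paper on tides."}],
        response_model=Outline,
    )
    print(outline.title, outline.section_headings)


asyncio.run(main())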