Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prospector option to select between form fill and doc creation #244

Merged
merged 8 commits into from
Nov 14, 2024
114 changes: 66 additions & 48 deletions assistants/prospector-assistant/assistant/agents/document_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,7 @@ def _get_step_method(self, step: Step | None) -> Callable | None:
return None
return self._step_name_to_method.get(step.name)

# Not currently used
async def receive_command(
self,
config: AssistantConfigModel,
Expand Down Expand Up @@ -311,7 +312,7 @@ def _set_mode_draft_outline(
self,
config: AssistantConfigModel,
context: ConversationContext,
message: ConversationMessage,
message: ConversationMessage | None,
metadata: dict[str, Any] = {},
) -> None:
# Pre-requisites
Expand All @@ -328,11 +329,12 @@ def _set_mode_draft_outline(
self._state.mode = Mode(name=ModeName.DRAFT_OUTLINE, status=Status.INITIATED)
self._write_state(context)

# Not currently used
def _set_mode_draft_paper(
self,
config: AssistantConfigModel,
context: ConversationContext,
message: ConversationMessage,
message: ConversationMessage | None,
metadata: dict[str, Any] = {},
) -> None:
# Pre-requisites
Expand All @@ -349,11 +351,11 @@ def _set_mode_draft_paper(
self._state.mode = Mode(name=ModeName.DRAFT_PAPER, status=Status.INITIATED)
self._write_state(context)

async def respond_to_conversation(
async def create_document(
self,
config: AssistantConfigModel,
context: ConversationContext,
message: ConversationMessage,
message: ConversationMessage | None,
metadata: dict[str, Any] = {},
) -> bool:
self._state = self._read_state(context)
Expand All @@ -364,15 +366,22 @@ async def respond_to_conversation(
return False

mode = self._state.mode
current_mode_name = mode.get_name()
correct_mode_name = ModeName.DRAFT_OUTLINE # Will update
if not mode.is_running():
self._set_mode_draft_outline(
config, context, message, metadata
) # Will update this mode as implementation expands to full document.
elif current_mode_name is not correct_mode_name:
logger.warning(
"Document Agent must be running in a mode to respond. Current mode: %s and status: %s",
mode.get_name(),
mode.get_status(),
"Document Agent not in correct mode. Returning. Current mode: %s Correct mode: %s",
current_mode_name,
correct_mode_name,
)
return mode.is_running()

# Run
mode = self._state.mode
logger.info("Document Agent in mode %s", mode.get_name())
mode_method = self._get_mode_method(mode)
if mode_method:
Expand Down Expand Up @@ -404,7 +413,7 @@ async def _run_mode(
self,
config: AssistantConfigModel,
context: ConversationContext,
message: ConversationMessage,
message: ConversationMessage | None,
metadata: dict[str, Any] = {},
) -> Status:
# Pre-requisites
Expand Down Expand Up @@ -518,7 +527,7 @@ async def _mode_draft_outline(
self,
config: AssistantConfigModel,
context: ConversationContext,
message: ConversationMessage,
message: ConversationMessage | None,
metadata: dict[str, Any] = {},
) -> Status:
# Pre-requisites
Expand Down Expand Up @@ -576,7 +585,7 @@ async def _mode_draft_paper(
self,
config: AssistantConfigModel,
context: ConversationContext,
message: ConversationMessage,
message: ConversationMessage | None,
metadata: dict[str, Any] = {},
) -> Status:
# Pre-requisites
Expand Down Expand Up @@ -634,7 +643,7 @@ async def _step_gc_attachment_check(
self,
config: AssistantConfigModel,
context: ConversationContext,
message: ConversationMessage,
message: ConversationMessage | None,
metadata: dict[str, Any] = {},
) -> tuple[Status, StepName | None]:
next_step = None
Expand Down Expand Up @@ -675,7 +684,7 @@ async def _step_draft_outline(
self,
config: AssistantConfigModel,
context: ConversationContext,
message: ConversationMessage,
message: ConversationMessage | None,
metadata: dict[str, Any] = {},
) -> tuple[Status, StepName | None]:
next_step = None
Expand Down Expand Up @@ -715,7 +724,7 @@ async def _step_gc_get_outline_feedback(
self,
config: AssistantConfigModel,
context: ConversationContext,
message: ConversationMessage,
message: ConversationMessage | None,
metadata: dict[str, Any] = {},
) -> tuple[Status, StepName | None]:
next_step_name = None
Expand Down Expand Up @@ -763,7 +772,7 @@ async def _step_finish(
self,
config: AssistantConfigModel,
context: ConversationContext,
message: ConversationMessage,
message: ConversationMessage | None,
metadata: dict[str, Any] = {},
) -> tuple[Status, StepName | None]:
# pretend completed
Expand All @@ -773,7 +782,7 @@ async def _step_draft_content(
self,
config: AssistantConfigModel,
context: ConversationContext,
message: ConversationMessage,
message: ConversationMessage | None,
metadata: dict[str, Any] = {},
) -> tuple[Status, StepName | None]:
next_step = None
Expand Down Expand Up @@ -819,7 +828,7 @@ async def _gc_attachment_check(
self,
config: AssistantConfigModel,
context: ConversationContext,
message: ConversationMessage,
message: ConversationMessage | None,
metadata: dict[str, Any] = {},
) -> tuple[Status, StepName | None]:
method_metadata_key = "document_agent_gc_attachment_check"
Expand Down Expand Up @@ -848,8 +857,12 @@ async def _gc_attachment_check(

# run guided conversation step
try:
if message is None:
user_message = None
else:
user_message = message.content
response_message, conversation_status, next_step_name = await guided_conversation.step_conversation(
last_user_message=message.content,
last_user_message=user_message,
)

# add the completion to the metadata for debugging
Expand Down Expand Up @@ -893,16 +906,18 @@ async def _draft_outline(
self,
config: AssistantConfigModel,
context: ConversationContext,
message: ConversationMessage,
message: ConversationMessage | None,
metadata: dict[str, Any] = {},
) -> tuple[Status, StepName | None]:
method_metadata_key = "draft_outline"

# get conversation related info
conversation = await context.get_messages(before=message.id)
if message.message_type == MessageType.chat:
conversation.messages.append(message)
participants_list = await context.get_participants(include_inactive=True)
# get conversation related info -- for now, if no message, assuming no prior conversation
conversation = None
if message is not None:
conversation = await context.get_messages(before=message.id)
if message.message_type == MessageType.chat:
conversation.messages.append(message)
participants_list = await context.get_participants(include_inactive=True)

# get attachments related info
attachment_messages = await self._attachments_extension.get_completion_messages_for_attachments(
Expand All @@ -918,9 +933,10 @@ async def _draft_outline(
# create chat completion messages
chat_completion_messages: list[ChatCompletionMessageParam] = []
chat_completion_messages.append(_draft_outline_main_system_message())
chat_completion_messages.append(
_chat_history_system_message(conversation.messages, participants_list.participants)
)
if conversation is not None:
chat_completion_messages.append(
_chat_history_system_message(conversation.messages, participants_list.participants)
)
chat_completion_messages.extend(attachment_messages)
if outline is not None:
chat_completion_messages.append(_outline_system_message(outline))
Expand Down Expand Up @@ -948,9 +964,9 @@ async def _draft_outline(
# store only latest version for now (will keep all versions later as need arises)
(storage_directory_for_context(context) / "document_agent/outline.txt").write_text(message_content)

# send the response to the conversation only if from a command. Otherwise return info to caller.
# send a command response to the conversation only if from a command. Otherwise return a normal chat message.
message_type = MessageType.chat
if message.message_type == MessageType.command:
if message is not None and message.message_type == MessageType.command:
message_type = MessageType.command

await context.send_messages(
Expand Down Expand Up @@ -978,11 +994,6 @@ async def _gc_outline_feedback(
return Status.UNDEFINED, StepName.UNDEFINED

# Run
if message is not None:
user_message = message.content
else:
user_message = None

gc_outline_feedback_config: GuidedConversationConfigModel = GCDraftOutlineFeedbackConfigModel()

guided_conversation = GuidedConversation(
Expand Down Expand Up @@ -1049,6 +1060,10 @@ async def _gc_outline_feedback(

# run guided conversation step
try:
if message is None:
user_message = None
else:
user_message = message.content
response_message, conversation_status, next_step_name = await guided_conversation.step_conversation(
last_user_message=user_message,
)
Expand Down Expand Up @@ -1094,16 +1109,18 @@ async def _draft_content(
self,
config: AssistantConfigModel,
context: ConversationContext,
message: ConversationMessage,
message: ConversationMessage | None,
metadata: dict[str, Any] = {},
) -> tuple[Status, StepName | None]:
method_metadata_key = "draft_content"

# get conversation related info
conversation = await context.get_messages(before=message.id)
if message.message_type == MessageType.chat:
conversation.messages.append(message)
participants_list = await context.get_participants(include_inactive=True)
# get conversation related info -- for now, if no message, assuming no prior conversation
conversation = None
if message is not None:
conversation = await context.get_messages(before=message.id)
if message.message_type == MessageType.chat:
conversation.messages.append(message)
participants_list = await context.get_participants(include_inactive=True)

# get attachments related info
attachment_messages = await self._attachments_extension.get_completion_messages_for_attachments(
Expand All @@ -1113,9 +1130,10 @@ async def _draft_content(
# create chat completion messages
chat_completion_messages: list[ChatCompletionMessageParam] = []
chat_completion_messages.append(_draft_content_main_system_message())
chat_completion_messages.append(
_chat_history_system_message(conversation.messages, participants_list.participants)
)
if conversation is not None:
chat_completion_messages.append(
_chat_history_system_message(conversation.messages, participants_list.participants)
)
chat_completion_messages.extend(attachment_messages)

# get outline related info
Expand All @@ -1139,29 +1157,29 @@ async def _draft_content(
"response_format": {"type": "text"},
}
completion = await client.chat.completions.create(**completion_args)
content = completion.choices[0].message.content
message_content = completion.choices[0].message.content
_on_success_metadata_update(metadata, method_metadata_key, config, chat_completion_messages, completion)

except Exception as e:
logger.exception(f"exception occurred calling openai chat completion: {e}")
content = (
message_content = (
"An error occurred while calling the OpenAI API. Is it configured correctly?"
"View the debug inspector for more information."
)
_on_error_metadata_update(metadata, method_metadata_key, config, chat_completion_messages, e)

if content is not None:
# store only latest version for now (will keep all versions later as need arises)
(storage_directory_for_context(context) / "document_agent/content.txt").write_text(content)
(storage_directory_for_context(context) / "document_agent/content.txt").write_text(message_content)

# send the response to the conversation only if from a command. Otherwise return info to caller.
# send a command response to the conversation only if from a command. Otherwise return a normal chat message.
message_type = MessageType.chat
if message.message_type == MessageType.command:
if message is not None and message.message_type == MessageType.command:
message_type = MessageType.command

await context.send_messages(
NewConversationMessage(
content=content,
content=message_content,
message_type=message_type,
metadata=metadata,
)
Expand Down
Loading