diff --git a/assistants/prospector-assistant/assistant/agents/document/guided_conversation.py b/assistants/prospector-assistant/assistant/agents/document/guided_conversation.py index 2ecc077f..01a2a1b9 100644 --- a/assistants/prospector-assistant/assistant/agents/document/guided_conversation.py +++ b/assistants/prospector-assistant/assistant/agents/document/guided_conversation.py @@ -2,6 +2,7 @@ import logging from enum import StrEnum from pathlib import Path +from typing import Generic, TypeVar from guided_conversation.guided_conversation_agent import GuidedConversation as GuidedConversationAgent from openai import AsyncOpenAI @@ -38,7 +39,10 @@ class GC_UserDecision(StrEnum): EXIT_EARLY = "exit_early" -class GuidedConversation: +TArtifactModel = TypeVar("TArtifactModel", bound=BaseModel) + + +class GuidedConversation(Generic[TArtifactModel]): """ An agent for managing artifacts. """ @@ -48,7 +52,7 @@ def __init__( config: AssistantConfigModel, openai_client: AsyncOpenAI, agent_config: GuidedConversationConfigModel, - artifact_model: type[BaseModel], + artifact_model: type[TArtifactModel], conversation_context: ConversationContext, artifact_updates: dict = {}, ) -> None: @@ -113,57 +117,44 @@ async def step_conversation( # Save the state of the guided conversation agent _write_guided_conversation_state(self.conversation_context, self.guided_conversation_agent.to_json()) - # convert information in artifact for Document Agent - # conversation_status: # this should relate to result.is_conversation_over - # final_response: # replace result.ai_message with final_response if "user_completed" - - final_response: str = "" - conversation_status_str: str | None = None - user_decision_str: str | None = None - response: str = "" - # to_json is actually to dict gc_dict = self.guided_conversation_agent.to_json() - artifact_item = gc_dict.get("artifact") - if artifact_item is not None: - artifact_item = artifact_item.get("artifact") - if artifact_item is not None: - final_response = artifact_item.get("final_response") - conversation_status_str = artifact_item.get("conversation_status") - user_decision_str = artifact_item.get("user_decision") + artifact_item = gc_dict["artifact"]["artifact"] + conversation_status_str: str | None = artifact_item.get("conversation_status") + user_decision_str: str | None = artifact_item.get("user_decision") + response: str = "" gc_conversation_status = GC_ConversationStatus.UNDEFINED - gc_user_decision: GC_UserDecision = GC_UserDecision.UNDEFINED - if conversation_status_str is not None: - match conversation_status_str: - case GC_ConversationStatus.USER_COMPLETED: - response = final_response or result.ai_message or "" - gc_conversation_status = GC_ConversationStatus.USER_COMPLETED - match user_decision_str: - case GC_UserDecision.UPDATE_OUTLINE: - gc_user_decision = GC_UserDecision.UPDATE_OUTLINE - case GC_UserDecision.DRAFT_PAPER: - gc_user_decision = GC_UserDecision.DRAFT_PAPER - case GC_UserDecision.UPDATE_CONTENT: - gc_user_decision = GC_UserDecision.UPDATE_CONTENT - case GC_UserDecision.DRAFT_NEXT_CONTENT: - gc_user_decision = GC_UserDecision.DRAFT_NEXT_CONTENT - case GC_UserDecision.EXIT_EARLY: - gc_user_decision = GC_UserDecision.EXIT_EARLY - - _delete_guided_conversation_state(self.conversation_context) - case GC_ConversationStatus.USER_INITIATED: - if result.ai_message is not None: - response = result.ai_message - else: - response = "" - gc_conversation_status = GC_ConversationStatus.USER_INITIATED - case GC_ConversationStatus.USER_RETURNED: - if result.ai_message is not None: - response = result.ai_message - else: - response = "" - gc_conversation_status = GC_ConversationStatus.USER_RETURNED + gc_user_decision = GC_UserDecision.UNDEFINED + + match conversation_status_str: + case GC_ConversationStatus.USER_COMPLETED: + gc_conversation_status = GC_ConversationStatus.USER_COMPLETED + final_response: str | None = artifact_item.get("final_response") + final_response = final_response if final_response != "Unanswered" else "" + response = final_response or result.ai_message or "" + + match user_decision_str: + case GC_UserDecision.UPDATE_OUTLINE: + gc_user_decision = GC_UserDecision.UPDATE_OUTLINE + case GC_UserDecision.DRAFT_PAPER: + gc_user_decision = GC_UserDecision.DRAFT_PAPER + case GC_UserDecision.UPDATE_CONTENT: + gc_user_decision = GC_UserDecision.UPDATE_CONTENT + case GC_UserDecision.DRAFT_NEXT_CONTENT: + gc_user_decision = GC_UserDecision.DRAFT_NEXT_CONTENT + case GC_UserDecision.EXIT_EARLY: + gc_user_decision = GC_UserDecision.EXIT_EARLY + + _delete_guided_conversation_state(self.conversation_context) + + case GC_ConversationStatus.USER_INITIATED: + gc_conversation_status = GC_ConversationStatus.USER_INITIATED + response = result.ai_message or "" + + case GC_ConversationStatus.USER_RETURNED: + gc_conversation_status = GC_ConversationStatus.USER_RETURNED + response = result.ai_message or "" return response, gc_conversation_status, gc_user_decision diff --git a/assistants/prospector-assistant/assistant/agents/document_agent.py b/assistants/prospector-assistant/assistant/agents/document_agent.py index ac670252..23abcabd 100644 --- a/assistants/prospector-assistant/assistant/agents/document_agent.py +++ b/assistants/prospector-assistant/assistant/agents/document_agent.py @@ -158,6 +158,11 @@ async def _mode_execute( ) break # ok - get more user input + case StepStatus.USER_EXIT_EARLY: + state.mode_status = ModeStatus.USER_EXIT_EARLY + logger.info("Document Agent: User exited early. Completed.") + break # ok - done early :) + case StepStatus.USER_COMPLETED: state.mode_status = ModeStatus.USER_COMPLETED @@ -202,11 +207,6 @@ def get_next_step(current_step_name: StepName, user_decision: GC_UserDecision) - ) continue # ok - don't need user input yet - case StepStatus.USER_EXIT_EARLY: - state.mode_status = ModeStatus.USER_EXIT_EARLY - logger.info("Document Agent: User exited early. Completed.") - break # ok - done early :) - return state.mode_status # endregion diff --git a/assistants/prospector-assistant/assistant/artifact_creation_extension/__init__.py b/assistants/prospector-assistant/assistant/artifact_creation_extension/__init__.py new file mode 100644 index 00000000..cd4135ec --- /dev/null +++ b/assistants/prospector-assistant/assistant/artifact_creation_extension/__init__.py @@ -0,0 +1,5 @@ +from .extension import ArtifactCreationExtension + +__all__ = [ + "ArtifactCreationExtension", +] diff --git a/assistants/prospector-assistant/assistant/artifact_creation_extension/_llm.py b/assistants/prospector-assistant/assistant/artifact_creation_extension/_llm.py new file mode 100644 index 00000000..7621f8d9 --- /dev/null +++ b/assistants/prospector-assistant/assistant/artifact_creation_extension/_llm.py @@ -0,0 +1,196 @@ +import logging +from abc import ABC, abstractmethod +from typing import Any, AsyncIterator, Awaitable, Callable, Coroutine, Generic, Iterable, TypeVar + +from attr import dataclass +from openai import pydantic_function_tool +from openai.types.chat import ( + ChatCompletionAssistantMessageParam, + ChatCompletionMessageParam, + ChatCompletionMessageToolCallParam, + ChatCompletionToolMessageParam, + ParsedFunctionToolCall, +) +from pydantic import BaseModel +from semantic_workbench_assistant.assistant_app.context import ConversationContext + +from .config import LLMConfig + +logger = logging.getLogger(__name__) + + +class NoResponseChoicesError(Exception): + pass + + +class NoParsedMessageError(Exception): + pass + + +ResponseModelT = TypeVar("ResponseModelT", bound=BaseModel) + + +async def structured_completion( + llm_config: LLMConfig, messages: list[ChatCompletionMessageParam], response_model: type[ResponseModelT] +) -> tuple[ResponseModelT, dict[str, Any]]: + async with llm_config.openai_client_factory() as client: + response = await client.beta.chat.completions.parse( + messages=messages, + model=llm_config.openai_model, + response_format=response_model, + max_tokens=llm_config.max_response_tokens, + ) + + if not response.choices: + raise NoResponseChoicesError() + + if not response.choices[0].message.parsed: + raise NoParsedMessageError() + + metadata = { + "request": { + "model": llm_config.openai_model, + "messages": messages, + "max_tokens": llm_config.max_response_tokens, + "response_format": response_model.model_json_schema(), + }, + "response": response.model_dump(), + } + + return response.choices[0].message.parsed, metadata + + +class ToolArgsModel(ABC, BaseModel): + @abstractmethod + def set_context(self, context: ConversationContext) -> None: ... + + +TToolArgs = TypeVar("TToolArgs", bound=ToolArgsModel) +TToolResult = TypeVar("TToolResult", bound=BaseModel) + + +@dataclass +class CompletionTool(Generic[TToolArgs, TToolResult]): + function: Callable[[TToolArgs], Coroutine[Any, Any, TToolResult]] + argument_model: type[TToolArgs] + description: str = "" + """Description of the tool. If omitted, wil use the docstring of the function.""" + + +@dataclass +class LLMResponse: + metadata: dict[str, Any] + + +@dataclass +class ToolCallResponse(LLMResponse): + tool_call: ParsedFunctionToolCall + result: BaseModel + + +@dataclass +class MessageResponse(LLMResponse): + message: str + + +async def completion_with_tools( + llm_config: LLMConfig, + context: ConversationContext, + get_messages: Callable[[], Awaitable[Iterable[ChatCompletionMessageParam]]], + tools: list[CompletionTool] = [], +) -> AsyncIterator[ToolCallResponse | MessageResponse]: + openai_tools = [ + pydantic_function_tool( + tool.argument_model, + name=tool.function.__name__, + description=tool.description or (tool.function.__doc__ or "").strip(), + ) + for tool in tools + ] + + tool_messages: list[ChatCompletionMessageParam] = [] + reasoning_effort = "medium" + + async with llm_config.openai_client_factory() as client: + while True: + completion_messages = list(await get_messages()) + tool_messages + + response = await client.beta.chat.completions.parse( + messages=completion_messages, + model=llm_config.openai_model, + max_completion_tokens=llm_config.max_response_tokens + 25_000, + tools=openai_tools, + reasoning_effort=reasoning_effort, + # parallel_tool_calls=False, + ) + + message = response.choices[0].message + + if not message.tool_calls: + metadata = { + "request": { + "model": llm_config.openai_model, + "messages": completion_messages, + "tools": openai_tools, + "reasoning_effort": reasoning_effort, + "max_completion_tokens": llm_config.max_response_tokens, + }, + "response": response.model_dump(), + } + yield MessageResponse(message=str(message.content), metadata=metadata) + return + + async with context.set_status("calling tools..."): + logger.info("tool calls: %s", message.tool_calls) + + # append the assistant message with the tool calls for the next iteration + tool_messages.append( + ChatCompletionAssistantMessageParam( + role="assistant", + tool_calls=[ + ChatCompletionMessageToolCallParam( + id=c.id, + function={ + "name": c.function.name, + "arguments": c.function.arguments, + }, + type="function", + ) + for c in message.tool_calls + ], + ) + ) + for tool_call in message.tool_calls: + function = tool_call.function + + # find the matching tool + tool = next((t for t in tools if t.function.__name__ == function.name), None) + if tool is None: + raise ValueError("Unknown tool call: %s", tool_call.function) + + # validate the args and call the tool function + args = tool.argument_model.model_validate(function.parsed_arguments) + args.set_context(context) + result: BaseModel = await tool.function(args) + + metadata = { + "request": { + "model": llm_config.openai_model, + "messages": completion_messages, + "tools": openai_tools, + "max_tokens": llm_config.max_response_tokens, + }, + "response": response.model_dump(), + "tool_call": tool_call.model_dump(mode="json"), + "tool_result": result.model_dump(mode="json"), + } + yield ToolCallResponse(tool_call=tool_call, result=result, metadata=metadata) + + # append the tool result to the messages for the next iteration + tool_messages.append( + ChatCompletionToolMessageParam( + content=result.model_dump_json(), + role="tool", + tool_call_id=tool_call.id, + ) + ) diff --git a/assistants/prospector-assistant/assistant/artifact_creation_extension/config.py b/assistants/prospector-assistant/assistant/artifact_creation_extension/config.py new file mode 100644 index 00000000..7956fc3a --- /dev/null +++ b/assistants/prospector-assistant/assistant/artifact_creation_extension/config.py @@ -0,0 +1,11 @@ +from dataclasses import dataclass +from typing import Callable + +from openai import AsyncOpenAI + + +@dataclass +class LLMConfig: + openai_client_factory: Callable[[], AsyncOpenAI] + openai_model: str + max_response_tokens: int diff --git a/assistants/prospector-assistant/assistant/artifact_creation_extension/document.py b/assistants/prospector-assistant/assistant/artifact_creation_extension/document.py new file mode 100644 index 00000000..698f38d4 --- /dev/null +++ b/assistants/prospector-assistant/assistant/artifact_creation_extension/document.py @@ -0,0 +1,79 @@ +import uuid +from datetime import datetime, timezone + +from pydantic import BaseModel, Field + + +class SectionMetadata(BaseModel): + purpose: str = "" + """Describes the intent of the section.""" + created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + """Timestamp for when the section was created.""" + last_modified_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + """Timestamp for the last modification.""" + + +class Section(BaseModel): + """ + Represents a section in a document, with a heading level, section number, title and content. + + Sections are the basic building blocks of a document. They are ordered within a document. They + have a heading level of 1-N. + """ + + heading_level: int + """The level of the section in the hierarchy. Top-level sections are level 1, and nested sections are level 2 and beyond.""" + section_number: str + """The number of the section in a heirarchical format. For example, 1.1.1. Section numbers are unique within the document.""" + + title: str + """The title of the section.""" + content: str = "" + """Content of the section, supporting Markdown for formatting.""" + + metadata: SectionMetadata = SectionMetadata() + """Metadata describing the section.""" + + +class DocumentMetadata(BaseModel): + """ + Metadata for a document, including title, purpose, audience, version, author, contributors, + and timestamps for creation and last modification. + """ + + document_id: str = Field(default_factory=lambda: uuid.uuid4().hex[0:8]) + + purpose: str = "" + """Describes the intent of the document""" + audience: str = "" + """Describes the intended audience for the document""" + other_guidelines: str = "" + """ + Describes any other guidelines or standards, stylistic, structure, reading level, etc., + that the document should follow + """ + supporting_documents: list[str] = Field(default_factory=list) + """List of document titles for supporting documents.""" + created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + """Timestamp for when the document was created.""" + last_modified_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + """Timestamp for the last modification.""" + + +class Document(BaseModel): + """ + Represents a complete document, including metadata, sections, and references to supporting documents. + """ + + title: str = "" + """Title of the document. Doubles as a unique identifier for the document.""" + sections: list[Section] = Field(default_factory=list) + """Structured content of the document.""" + + metadata: DocumentMetadata = DocumentMetadata() + """Metadata describing the document.""" + + +class DocumentHeader(BaseModel): + document_id: str + title: str diff --git a/assistants/prospector-assistant/assistant/artifact_creation_extension/extension.py b/assistants/prospector-assistant/assistant/artifact_creation_extension/extension.py new file mode 100644 index 00000000..b50e8778 --- /dev/null +++ b/assistants/prospector-assistant/assistant/artifact_creation_extension/extension.py @@ -0,0 +1,253 @@ +import logging +from textwrap import dedent + +import openai_client +from openai.types.chat import ( + ChatCompletionAssistantMessageParam, + ChatCompletionDeveloperMessageParam, + ChatCompletionMessageParam, + ChatCompletionUserMessageParam, +) +from semantic_workbench_api_model.workbench_model import ( + ConversationEvent, + ConversationMessage, + MessageType, + NewConversationMessage, + ParticipantRole, +) +from semantic_workbench_assistant.assistant_app.config import BaseModelAssistantConfig +from semantic_workbench_assistant.assistant_app.context import ConversationContext +from semantic_workbench_assistant.assistant_app.protocol import AssistantAppProtocol + +from ..config import AssistantConfigModel +from . import store, tools +from ._llm import CompletionTool, MessageResponse, ToolCallResponse, completion_with_tools +from .config import LLMConfig +from .store import ActiveDocumentInspector, DocumentWorkspaceInspector + +logger = logging.getLogger(__name__) + + +system_message = dedent( + """ + You are an assistant. Ultimately, you help users create documents in a document workspace. To do this, you + will assist with ideation, drafting, and editing. Documents are can represent a variety of content types, + such as reports, articles, blog posts, stories, slide decks or others. You can create, update, and remove + documents, as well as create, update, and remove sections within documents. You can also list all documents + in the workspace. + The documents in the workspace, in their current state, are available here for your reference. + + Documents: + --- + """ +) + + +class ArtifactCreationExtension: + def __init__( + self, assistant_app: AssistantAppProtocol, assistant_config: BaseModelAssistantConfig[AssistantConfigModel] + ) -> None: + document_workspace_inspector = DocumentWorkspaceInspector() + active_document_inspector = ActiveDocumentInspector() + assistant_app.add_inspector_state_provider( + document_workspace_inspector.display_name, document_workspace_inspector + ) + assistant_app.add_inspector_state_provider(active_document_inspector.display_name, active_document_inspector) + + @assistant_app.events.conversation.message.command.on_created + async def on_message_command_created( + context: ConversationContext, _: ConversationEvent, message: ConversationMessage + ) -> None: + config = await assistant_config.get(context.assistant) + if config.guided_workflow != "Long Document Creation": + return + + match message.content.split(" ")[0]: + case "/help": + await _send_message( + context, + dedent(""" + /help - Display this help message. + /ls - List all documents in the workspace. + /select - Select the active document. + """), + {}, + message_type=MessageType.command_response, + ) + + case "/ls": + args = tools.ListDocumentsArgs() + args.set_context(context) + headers = await tools.list_documents(args) + document_list = "\n".join( + f"{index}. {header.title} ({header.document_id})" + for index, header in enumerate(headers.documents) + ) + await _send_message( + context, + f"Documents in the workspace: {headers.count}\n\n{document_list}", + {}, + message_type=MessageType.command_response, + ) + + case "/select": + index = int(message.content.split(" ")[1]) + list_args = tools.ListDocumentsArgs() + list_args.set_context(context) + headers = await tools.list_documents(list_args) + store.active_document_id = headers.documents[index].document_id + await _send_message( + context, + f"Selected document: {headers.documents[index].title}", + {}, + message_type=MessageType.command_response, + ) + + case _: + await _send_message( + context, + "Unknown command. Use /help to see available commands.", + {}, + message_type=MessageType.command_response, + ) + + @assistant_app.events.conversation.message.chat.on_created + async def on_message_chat_created( + context: ConversationContext, _: ConversationEvent, message: ConversationMessage + ) -> None: + config = await assistant_config.get(context.assistant) + if config.guided_workflow != "Long Document Creation": + return + + async with context.set_status("responding ..."): + messages_response = await context.get_messages(before=message.id) + + async def message_generator() -> list[ChatCompletionMessageParam]: + messages: list[ChatCompletionMessageParam] = [] + for msg in (*messages_response.messages, message): + match msg.sender.participant_role: + case ParticipantRole.user: + messages.append(ChatCompletionUserMessageParam(content=msg.content, role="user")) + + case ParticipantRole.assistant: + messages.append( + ChatCompletionAssistantMessageParam(content=msg.content, role="assistant") + ) + + list_args = tools.ListDocumentsArgs() + list_args.set_context(context) + headers = await tools.list_documents(list_args) + document_content_list = "" + if not headers: + document_content_list = "There are currently documents in the workspace." + + for header in headers.documents: + get_document_args = tools.GetDocumentArgs(document_id=header.document_id) + get_document_args.set_context(context) + document = await tools.get_document(get_document_args) + + document_content_list += f"\n\n```json\n{document.model_dump_json()}\n```" + + messages.append( + # ChatCompletionSystemMessageParam(content=system_message + document_content_list, role="system") + ChatCompletionDeveloperMessageParam( + content=system_message + document_content_list, role="developer" + ) + ) + return messages + + completion_tools = [ + CompletionTool( + function=tools.create_document, + argument_model=tools.CreateDocumentArgs, + ), + CompletionTool( + function=tools.update_document, + argument_model=tools.UpdateDocumentArgs, + ), + CompletionTool( + function=tools.remove_document, + argument_model=tools.RemoveDocumentArgs, + ), + CompletionTool( + function=tools.get_document, + argument_model=tools.GetDocumentArgs, + ), + CompletionTool( + function=tools.create_document_section, + argument_model=tools.CreateDocumentSectionArgs, + ), + CompletionTool( + function=tools.update_document_section, + argument_model=tools.UpdateDocumentSectionArgs, + ), + CompletionTool( + function=tools.remove_document_section, + argument_model=tools.RemoveDocumentSectionArgs, + ), + CompletionTool( + function=tools.list_documents, + argument_model=tools.ListDocumentsArgs, + ), + ] + + config = await assistant_config.get(context.assistant) + config.service_config.azure_openai_deployment = "o3-mini" # type: ignore + llm_config = LLMConfig( + openai_client_factory=lambda: openai_client.create_client(config.service_config), + openai_model="o3-mini", # config.request_config.openai_model, + max_response_tokens=config.request_config.response_tokens, + ) + + try: + async for response in completion_with_tools( + llm_config=llm_config, + context=context, + get_messages=lambda: message_generator(), + tools=completion_tools, + ): + match response: + case MessageResponse(): + await _send_message(context, response.message, response.metadata) + + case ToolCallResponse(): + async with ( + context.state_updated_event_after(document_workspace_inspector.display_name), + context.state_updated_event_after(active_document_inspector.display_name), + ): + await _send_message( + context, + f"Called {response.tool_call.function.name}", + response.metadata, + message_type=MessageType.notice, + ) + + except Exception as e: + logger.exception("Failed to generate completion.") + await _send_error_message(context, "Failed to generate completion.", {"error": str(e)}) + return + + +async def _send_message( + context: ConversationContext, message: str, debug: dict, message_type: MessageType = MessageType.chat +) -> None: + if not message: + return + + await context.send_messages( + NewConversationMessage( + content=message, + message_type=message_type, + debug_data=debug, + ) + ) + + +async def _send_error_message(context: ConversationContext, message: str, debug: dict) -> None: + await context.send_messages( + NewConversationMessage( + content=message, + message_type=MessageType.notice, + debug_data=debug, + ) + ) diff --git a/assistants/prospector-assistant/assistant/artifact_creation_extension/store.py b/assistants/prospector-assistant/assistant/artifact_creation_extension/store.py new file mode 100644 index 00000000..3d115b69 --- /dev/null +++ b/assistants/prospector-assistant/assistant/artifact_creation_extension/store.py @@ -0,0 +1,123 @@ +from contextlib import contextmanager +from pathlib import Path +from typing import Iterator + +import yaml +from semantic_workbench_assistant.assistant_app.context import ConversationContext, storage_directory_for_context +from semantic_workbench_assistant.assistant_app.protocol import ( + AssistantConversationInspectorStateDataModel, + ReadOnlyAssistantConversationInspectorStateProvider, +) + +from .document import Document, DocumentHeader + + +class DocumentStore: + def __init__(self, store_path: Path): + store_path.mkdir(parents=True, exist_ok=True) + self.store_path = store_path + + def _path_for(self, id: str) -> Path: + return self.store_path / f"{id}.json" + + def write(self, document: Document) -> None: + path = self._path_for(document.metadata.document_id) + path.write_text(document.model_dump_json(indent=2)) + + def read(self, id: str) -> Document: + path = self._path_for(id) + try: + return Document.model_validate_json(path.read_text()) + except FileNotFoundError: + raise ValueError(f"Document not found: {id}") + + @contextmanager + def checkout(self, id: str) -> Iterator[Document]: + document = self.read(id=id) + yield document + self.write(document) + + def delete(self, id: str) -> None: + path = self._path_for(id) + path.unlink(missing_ok=True) + + def list_documents(self) -> list[DocumentHeader]: + documents = [] + for path in self.store_path.glob("*.json"): + document = Document.model_validate_json(path.read_text()) + documents.append(DocumentHeader(document_id=document.metadata.document_id, title=document.title)) + + return sorted(documents, key=lambda document: document.title.lower()) + + +def for_context(context: ConversationContext) -> DocumentStore: + doc_store_root = storage_directory_for_context(context) / "document_store" + return DocumentStore(doc_store_root) + + +def project_to_yaml(state: dict | list[dict]) -> str: + """ + Project the state to a yaml code block. + """ + state_as_yaml = yaml.dump(state, sort_keys=False) + return f"```yaml\n{state_as_yaml}\n```" + + +class DocumentWorkspaceInspector(ReadOnlyAssistantConversationInspectorStateProvider): + @property + def display_name(self) -> str: + return "Document Workspace" + + @property + def description(self) -> str: + return "Documents in the workspace." + + async def get(self, context: ConversationContext) -> AssistantConversationInspectorStateDataModel: + store = for_context(context) + documents: list[dict] = [] + for header in store.list_documents(): + doc = store.read(header.document_id) + documents.append(doc.model_dump(mode="json")) + projected = project_to_yaml(documents) + return AssistantConversationInspectorStateDataModel(data={"content": projected}) + + +active_document_id: str | None = None + + +class ActiveDocumentInspector(ReadOnlyAssistantConversationInspectorStateProvider): + @property + def display_name(self) -> str: + return "Active Document" + + @property + def description(self) -> str: + return "The active document." + + async def get(self, context: ConversationContext) -> AssistantConversationInspectorStateDataModel: + global active_document_id + store = for_context(context) + headers = store.list_documents() + if not headers: + return AssistantConversationInspectorStateDataModel(data={"content": "No active document."}) + + if active_document_id is None: + active_document_id = headers[0].document_id + + doc = store.read(active_document_id) + + projected = project_document_to_markdown(doc) + + return AssistantConversationInspectorStateDataModel(data={"content": projected}) + + +def project_document_to_markdown(doc: Document) -> str: + """ + Project the document to a markdown code block. + """ + markdown = f"# {doc.title}\n\n***{doc.metadata.purpose}***\n\n" + for section in doc.sections: + markdown += f"{'#' * section.heading_level} {section.section_number} {section.title}\n\n{section.content}\n\n" + markdown += "-" * 3 + "\n\n" + + return f"```markdown\n{markdown}\n```" diff --git a/assistants/prospector-assistant/assistant/artifact_creation_extension/tools.py b/assistants/prospector-assistant/assistant/artifact_creation_extension/tools.py new file mode 100644 index 00000000..83112cf1 --- /dev/null +++ b/assistants/prospector-assistant/assistant/artifact_creation_extension/tools.py @@ -0,0 +1,272 @@ +from collections import defaultdict +from datetime import datetime, timezone +from typing import Optional + +from pydantic import BaseModel, Field +from semantic_workbench_assistant.assistant_app.context import ConversationContext + +from ._llm import ToolArgsModel +from .document import Document, DocumentHeader, DocumentMetadata, Section, SectionMetadata +from .store import DocumentStore, for_context + + +class ArgsWithDocumentStore(ToolArgsModel): + def set_context(self, context: ConversationContext) -> None: + self._context = context + + @property + def store(self) -> DocumentStore: + return for_context(self._context) + + +class CreateDocumentArgs(ArgsWithDocumentStore): + title: str = Field(description="Document title") + purpose: Optional[str] = Field(description="Describes the intent of the document.") + audience: Optional[str] = Field(description="Describes the intended audience for the document.") + other_guidelines: Optional[str] = Field( + description="Describes any other guidelines or standards that the document should follow." + ) + + +async def create_document(args: CreateDocumentArgs) -> DocumentMetadata: + """ + Create a new document with the specified metadata. + """ + metadata = DocumentMetadata() + if args.purpose is not None: + metadata.purpose = args.purpose + if args.audience is not None: + metadata.audience = args.audience + if args.other_guidelines is not None: + metadata.other_guidelines = args.other_guidelines + document = Document(title=args.title, metadata=metadata) + + args.store.write(document) + + return document.metadata + + +class UpdateDocumentArgs(ArgsWithDocumentStore): + document_id: str = Field(description="The id of the document to update.") + title: Optional[str] = Field(description="The updated title of the document. Pass None to leave unchanged.") + purpose: Optional[str] = Field( + description="Describes the intent of the document. Can be left blank. Pass None to leave unchanged." + ) + audience: Optional[str] = Field( + description="Describes the intended audience for the document. Can be left blank. Pass None to leave unchanged." + ) + other_guidelines: Optional[str] = Field( + description="Describes any other guidelines or standards that the document should follow. Can be left blank. Pass None to leave unchanged." + ) + + +async def update_document(args: UpdateDocumentArgs) -> DocumentMetadata: + """ + Update the metadata of an existing document. + """ + with args.store.checkout(args.document_id) as document: + if args.title is not None: + document.title = args.title + if args.purpose is not None: + document.metadata.purpose = args.purpose + if args.audience is not None: + document.metadata.audience = args.audience + if args.other_guidelines is not None: + document.metadata.other_guidelines = args.other_guidelines + + document.metadata.last_modified_at = datetime.now(timezone.utc) + + return document.metadata + + +class GetDocumentArgs(ArgsWithDocumentStore): + document_id: str = Field(description="The id of the document to retrieve.") + + +async def get_document(args: GetDocumentArgs) -> Document: + """ + Retrieve a document by its id. + """ + return args.store.read(id=args.document_id) + + +class RemoveDocumentArgs(ArgsWithDocumentStore): + document_id: str = Field(description="The id of the document to remove.") + + +async def remove_document(args: RemoveDocumentArgs) -> Document: + """ + Remove a document from the workspace. + """ + document = args.store.read(id=args.document_id) + args.store.delete(id=args.document_id) + return document + + +class CreateDocumentSectionArgs(ArgsWithDocumentStore): + document_id: str = Field(description="The id of the document to add the section to.") + insert_before_section_number: Optional[str] = Field( + description="The section number of the section to insert the new section ***before***." + " Pass None to insert at the end of the document, after all existing sections, if any." + " For example, if there are sections '1', '2', and '3', and you want to insert a section" + " between '2' and '3'. Then the insert_before_section_number should be '3'.", + ) + section_heading_level: int = Field(description="The heading level of the new section.") + section_title: str = Field(description="The title of the new section.") + section_purpose: Optional[str] = Field(description="Describes the intent of the new section.") + section_content: str = Field(description="The content of the new section. Can be left blank.") + + +async def create_document_section(args: CreateDocumentSectionArgs) -> Section: + """ + Create a new section in an existing document. + """ + + with args.store.checkout(args.document_id) as document: + document.metadata.last_modified_at = datetime.now(timezone.utc) + + metadata = SectionMetadata() + if args.section_purpose is not None: + metadata.purpose = args.section_purpose + + heading_level = args.section_heading_level + insert_at_index = len(document.sections) + if args.insert_before_section_number is not None: + _, insert_at_index = _find_section(args.insert_before_section_number, document) + if insert_at_index == -1: + raise ValueError( + f"Section {args.insert_before_section_number} not found in document {args.document_id}" + ) + + section = Section( + title=args.section_title, + content=args.section_content, + metadata=metadata, + section_number="will be renumbered", + heading_level=heading_level, + ) + + document.sections.insert(insert_at_index, section) + + _renumber_sections(document.sections) + + return section + + +class UpdateDocumentSectionArgs(ArgsWithDocumentStore): + document_id: str = Field(description="The id of the document containing the section to update.") + section_number: str = Field(description="The number of the section to update.") + section_heading_level: Optional[int] = Field( + description="The updated heading level of the section. Pass None to leave unchanged." + ) + section_title: Optional[str] = Field(description="The updated title of the section. Pass None to leave unchanged.") + section_purpose: Optional[str] = Field( + description="The updated purpose of the new section. Pass None to leave unchanged." + ) + section_content: Optional[str] = Field( + description="The updated content of the section. Pass None to leave unchanged." + ) + + +async def update_document_section(args: UpdateDocumentSectionArgs) -> Section: + """ + Update the content of a section in an existing document. + """ + with args.store.checkout(args.document_id) as document: + section, _ = _find_section(args.section_number, document) + if section is None: + raise ValueError(f"Section {args.section_number} not found in document {args.document_id}") + + if args.section_heading_level is not None: + section.heading_level = args.section_heading_level + if args.section_title is not None: + section.title = args.section_title + if args.section_purpose is not None: + section.metadata.purpose = args.section_purpose + if args.section_content is not None: + section.content = args.section_content + + document.metadata.last_modified_at = datetime.now(timezone.utc) + _renumber_sections(document.sections) + + return section + + +class RemoveDocumentSectionArgs(ArgsWithDocumentStore): + document_id: str = Field(description="The id of the document containing the section to remove.") + section_number: str = Field(description="The section number of the section to remove.") + + +async def remove_document_section(args: RemoveDocumentSectionArgs) -> Section: + """ + Remove a section from an existing document. Note that removing a section will also remove all nested sections. + """ + with args.store.checkout(args.document_id) as document: + section, _ = _find_section(args.section_number, document) + if section is None: + raise ValueError(f"Section with number {args.section_number} not found in document {args.document_id}") + + document.sections.remove(section) + + _renumber_sections(document.sections) + + document.metadata.last_modified_at = datetime.now(timezone.utc) + + return section + + +class DocumentList(BaseModel): + documents: list[DocumentHeader] + count: int = Field(description="The number of documents in the workspace.") + + +class ListDocumentsArgs(ArgsWithDocumentStore): + pass + + +async def list_documents(args: ListDocumentsArgs) -> DocumentList: + """ + List the titles of all documents in the workspace. + """ + headers = args.store.list_documents() + return DocumentList(documents=headers, count=len(headers)) + + +def _find_section(section_number: str, document: Document) -> tuple[Section | None, int]: + section, index = next( + ( + (section, index) + for index, section in enumerate(document.sections) + if section.section_number == section_number + ), + (None, -1), + ) + return section, index + + +def _renumber_sections(sections: list[Section]) -> None: + """ + Renumber the sections in the list. + """ + current_heading_level = -1 + sections_at_level = defaultdict(lambda: 0) + current_section_number_parts: list[str] = [] + + for section in sections: + if section.heading_level == current_heading_level: + sections_at_level[section.heading_level] += 1 + current_section_number_parts.pop() + + if section.heading_level > current_heading_level: + current_heading_level = section.heading_level + sections_at_level[section.heading_level] = 1 + + if section.heading_level < current_heading_level: + for i in range(current_heading_level - section.heading_level): + sections_at_level.pop(current_heading_level + i, 0) + current_heading_level = section.heading_level + sections_at_level[section.heading_level] += 1 + current_section_number_parts = current_section_number_parts[: section.heading_level - 1] + + current_section_number_parts.append(str(sections_at_level[current_heading_level])) + section.section_number = ".".join(current_section_number_parts) diff --git a/assistants/prospector-assistant/assistant/chat.py b/assistants/prospector-assistant/assistant/chat.py index 3fc3d40b..7fa9b058 100644 --- a/assistants/prospector-assistant/assistant/chat.py +++ b/assistants/prospector-assistant/assistant/chat.py @@ -38,6 +38,7 @@ from . import legacy from .agents.artifact_agent import Artifact, ArtifactAgent, ArtifactConversationInspectorStateProvider from .agents.document_agent import DocumentAgent +from .artifact_creation_extension.extension import ArtifactCreationExtension from .config import AssistantConfigModel from .form_fill_extension import FormFillExtension, LLMConfig @@ -82,6 +83,7 @@ async def content_evaluator_factory(context: ConversationContext) -> ContentSafe attachments_extension = AttachmentsExtension(assistant) form_fill_extension = FormFillExtension(assistant) +artifact_creation_extension = ArtifactCreationExtension(assistant, assistant_config) # # create the FastAPI app instance @@ -132,12 +134,15 @@ async def on_chat_message_created( - @assistant.events.conversation.message.on_created """ + config = await assistant_config.get(context.assistant) + if config.guided_workflow == "Long Document Creation": + return + # update the participant status to indicate the assistant is responding async with send_error_message_on_exception(context), context.set_status("responding..."): # # NOTE: we're experimenting with agents, if they are enabled, use them to respond to the conversation # - config = await assistant_config.get(context.assistant) metadata: dict[str, Any] = {"debug": {"content_safety": event.data.get(content_safety.metadata_key, {})}} match config.guided_workflow: @@ -168,6 +173,7 @@ async def on_conversation_created(context: ConversationContext) -> None: config = await assistant_config.get(context.assistant) metadata: dict[str, Any] = {"debug": {}} + task: asyncio.Task | None = None match config.guided_workflow: case "Form Completion": task = asyncio.create_task(welcome_message_form_fill(context)) @@ -175,12 +181,15 @@ async def on_conversation_created(context: ConversationContext) -> None: task = asyncio.create_task( welcome_message_create_document(config, context, message=None, metadata=metadata) ) + case "Long Document Creation": + pass case _: logger.error("Guided workflow unknown or not supported.") return - background_tasks.add(task) - task.add_done_callback(background_tasks.remove) + if task: + background_tasks.add(task) + task.add_done_callback(background_tasks.remove) async def welcome_message_form_fill(context: ConversationContext) -> None: diff --git a/assistants/prospector-assistant/assistant/config.py b/assistants/prospector-assistant/assistant/config.py index 56219c16..2bcd1cc1 100644 --- a/assistants/prospector-assistant/assistant/config.py +++ b/assistants/prospector-assistant/assistant/config.py @@ -118,12 +118,12 @@ class RequestConfig(BaseModel): # the workbench app builds dynamic forms based on the configuration model and UI schema class AssistantConfigModel(BaseModel): guided_workflow: Annotated[ - Literal["Form Completion", "Document Creation"], + Literal["Form Completion", "Document Creation", "Long Document Creation"], Field( title="Guided Workflow", description="The workflow extension to guide this conversation.", ), - ] = "Document Creation" + ] = "Form Completion" enable_debug_output: Annotated[ bool,