diff --git a/libraries/python/semantic-workbench-api-model/semantic_workbench_api_model/workbench_model.py b/libraries/python/semantic-workbench-api-model/semantic_workbench_api_model/workbench_model.py
index 9323a96e..ae735081 100644
--- a/libraries/python/semantic-workbench-api-model/semantic_workbench_api_model/workbench_model.py
+++ b/libraries/python/semantic-workbench-api-model/semantic_workbench_api_model/workbench_model.py
@@ -161,6 +161,7 @@ class ConversationMessage(BaseModel):
     content: str
     filenames: list[str]
     metadata: dict[str, Any]
+    has_debug_data: bool
 
     @property
     def command_name(self) -> str:
@@ -177,6 +178,11 @@ def command_args(self) -> str:
         return "".join(self.content.split(" ", 1)[1:])
 
 
+class ConversationMessageDebug(BaseModel):
+    message_id: uuid.UUID
+    debug_data: dict[str, Any]
+
+
 class ConversationMessageList(BaseModel):
     messages: list[ConversationMessage]
 
@@ -435,6 +441,7 @@ class NewConversationMessage(BaseModel):
     content_type: str = "text/plain"
     filenames: list[str] | None = None
     metadata: dict[str, Any] | None = None
+    debug_data: dict[str, Any] | None = None
 
 
 class NewConversationShare(BaseModel):
diff --git a/libraries/python/semantic-workbench-assistant/tests/test_assistant_app.py b/libraries/python/semantic-workbench-assistant/tests/test_assistant_app.py
index cf45c9f3..cc1faca4 100644
--- a/libraries/python/semantic-workbench-assistant/tests/test_assistant_app.py
+++ b/libraries/python/semantic-workbench-assistant/tests/test_assistant_app.py
@@ -157,6 +157,7 @@ async def on_chat_message(
                     content="Hello, world",
                     filenames=[],
                     metadata={},
+                    has_debug_data=False,
                 ).model_dump(mode="json")
             },
         )
@@ -184,6 +185,7 @@ async def on_chat_message(
                     content="Hello, world",
                     filenames=[],
                     metadata={},
+                    has_debug_data=False,
                 ).model_dump(mode="json")
             },
         )
@@ -211,6 +213,7 @@ async def on_chat_message(
                     content="Hello, world",
                     filenames=[],
                     metadata={},
+                    has_debug_data=False,
                 ).model_dump(mode="json")
             },
         )
diff --git a/workbench-app/src/components/Conversations/DebugInspector.tsx b/workbench-app/src/components/Conversations/DebugInspector.tsx
index 124bc976..a8913570 100644
--- a/workbench-app/src/components/Conversations/DebugInspector.tsx
+++ b/workbench-app/src/components/Conversations/DebugInspector.tsx
@@ -1,10 +1,11 @@
 // Copyright (c) Microsoft. All rights reserved.
 
-import { Button, Tooltip, makeStyles } from '@fluentui/react-components';
+import { Button, DialogOpenChangeData, DialogOpenChangeEvent, Tooltip, makeStyles } from '@fluentui/react-components';
 import { Info16Regular } from '@fluentui/react-icons';
 import React from 'react';
 import { JSONTree } from 'react-json-tree';
 import { DialogControl } from '../App/DialogControl';
+import { Loading } from '../App/Loading';
 import { ContentRenderer } from './ContentRenderers/ContentRenderer';
 
 const useClasses = makeStyles({
@@ -23,13 +24,27 @@ const useClasses = makeStyles({
 
 interface DebugInspectorProps {
     debug?: { [key: string]: any };
+    loading?: boolean;
     trigger?: JSX.Element;
+    onOpen?: () => void;
+    onClose?: () => void;
 }
 
 export const DebugInspector: React.FC<DebugInspectorProps> = (props) => {
-    const { debug, trigger } = props;
+    const { debug, loading, trigger, onOpen, onClose } = props;
     const classes = useClasses();
 
+    const onOpenChanged = React.useCallback(
+        (_: DialogOpenChangeEvent, data: DialogOpenChangeData) => {
+            if (data.open) {
+                onOpen?.();
+                return;
+            }
+            onClose?.();
+        },
+        [onOpen, onClose],
+    );
+
     if (!debug) {
         return null;
     }
@@ -48,8 +63,11 @@ export const DebugInspector: React.FC<DebugInspectorProps> = (props) => {
             }
             classNames={{ dialogSurface: classes.root }}
             title="Debug Inspection"
+            onOpenChange={onOpenChanged}
             content={
-                debug.content ? (
+                loading ? (
+                    <Loading />
+                ) : debug.content ? (
                 ) : (
diff --git a/workbench-app/src/components/Conversations/InteractInput.tsx b/workbench-app/src/components/Conversations/InteractInput.tsx
index dd05ec31..205752cc 100644
--- a/workbench-app/src/components/Conversations/InteractInput.tsx
+++ b/workbench-app/src/components/Conversations/InteractInput.tsx
@@ -354,6 +354,7 @@ export const InteractInput: React.FC = (props) => {
                         contentType: 'text/plain',
                         filenames: [],
                         metadata,
+                        hasDebugData: false,
                     },
                 ]),
             );
diff --git a/workbench-app/src/components/Conversations/InteractMessage.tsx b/workbench-app/src/components/Conversations/InteractMessage.tsx
index c1d6f3fa..3295426f 100644
--- a/workbench-app/src/components/Conversations/InteractMessage.tsx
+++ b/workbench-app/src/components/Conversations/InteractMessage.tsx
@@ -36,7 +36,10 @@ import { Utility } from '../../libs/Utility';
 import { Conversation } from '../../models/Conversation';
 import { ConversationMessage } from '../../models/ConversationMessage';
 import { ConversationParticipant } from '../../models/ConversationParticipant';
-import { useCreateConversationMessageMutation } from '../../services/workbench';
+import {
+    useCreateConversationMessageMutation,
+    useGetConversationMessageDebugDataQuery,
+} from '../../services/workbench';
 import { CopyButton } from '../App/CopyButton';
 import { ContentRenderer } from './ContentRenderers/ContentRenderer';
 import { ConversationFileIcon } from './ConversationFileIcon';
@@ -153,6 +156,15 @@ export const InteractMessage: React.FC = (props) => {
     const { getAvatarData } = useParticipantUtility();
     const [createConversationMessage] = useCreateConversationMessageMutation();
     const { isMessageVisibleRef, isMessageVisible, isUnread } = useConversationUtility();
+    const [skipDebugLoad, setSkipDebugLoad] = React.useState(true);
+    const {
+        data: debugData,
+        isLoading: isLoadingDebugData,
+        isUninitialized: isUninitializedDebugData,
+    } = useGetConversationMessageDebugDataQuery(
+        { conversationId: conversation.id, messageId: message.id },
+        { skip: skipDebugLoad },
+    );
 
     const isUser = participant.role === 'user';
 
@@ -228,7 +240,14 @@ export const InteractMessage: React.FC = (props) => {
         () => (
             <>
                 {!readOnly && }
-
+                        {
+                            console.log('OPEN!');
+                            setSkipDebugLoad(false);
+                        }}
+                    />
                 {!readOnly && (
                     <>
@@ -238,7 +257,7 @@ export const InteractMessage: React.FC = (props) => {
                 )}
             </>
         ),
-        [conversation, message, readOnly],
+        [conversation, debugData?.debugData, isLoadingDebugData, isUninitializedDebugData, message, readOnly],
     );
 
     const getRenderedMessage = React.useCallback(() => {
diff --git a/workbench-app/src/models/ConversationMessage.ts b/workbench-app/src/models/ConversationMessage.ts
index f049112b..c349faf6 100644
--- a/workbench-app/src/models/ConversationMessage.ts
+++ b/workbench-app/src/models/ConversationMessage.ts
@@ -14,6 +14,7 @@ export interface ConversationMessage {
     metadata?: {
         [key: string]: any;
     };
+    hasDebugData: boolean;
 }
 
 export const conversationMessageFromJSON = (json: any): ConversationMessage => {
@@ -29,5 +30,6 @@ export const conversationMessageFromJSON = (json: any): ConversationMessage => {
         contentType: json.content_type,
         filenames: json.filenames,
         metadata: json.metadata,
+        hasDebugData: json.has_debug_data,
     };
 };
diff --git a/workbench-app/src/models/ConversationMessageDebug.ts b/workbench-app/src/models/ConversationMessageDebug.ts
new file mode 100644
index 00000000..23540fe8
--- /dev/null
+++ b/workbench-app/src/models/ConversationMessageDebug.ts
@@ -0,0 +1,15 @@
+// Copyright (c) Microsoft. All rights reserved.
+
+export interface ConversationMessageDebug {
+    id: string;
+    debugData: {
+        [key: string]: any;
+    };
+}
+
+export const conversationMessageDebugFromJSON = (json: any): ConversationMessageDebug => {
+    return {
+        id: json.id,
+        debugData: json.debug_data,
+    };
+};
diff --git a/workbench-app/src/services/workbench/conversation.ts b/workbench-app/src/services/workbench/conversation.ts
index 3e14335e..ac785565 100644
--- a/workbench-app/src/services/workbench/conversation.ts
+++ b/workbench-app/src/services/workbench/conversation.ts
@@ -1,5 +1,6 @@
 import { Conversation } from '../../models/Conversation';
-import { ConversationMessage } from '../../models/ConversationMessage';
+import { ConversationMessage, conversationMessageFromJSON } from '../../models/ConversationMessage';
+import { ConversationMessageDebug, conversationMessageDebugFromJSON } from '../../models/ConversationMessageDebug';
 import { transformResponseToConversationParticipant } from './participant';
 import { workbenchApi } from './workbench';
 
@@ -44,6 +45,14 @@ export const conversationApi = workbenchApi.injectEndpoints({
             providesTags: ['Conversation'],
             transformResponse: (response: any) => transformResponseToConversationMessages(response),
         }),
+        getConversationMessageDebugData: builder.query<
+            ConversationMessageDebug,
+            { conversationId: string; messageId: string }
+        >({
+            query: ({ conversationId, messageId }) =>
+                `/conversations/${conversationId}/messages/${messageId}/debug_data`,
+            transformResponse: (response: any) => transformResponseToConversationMessageDebug(response),
+        }),
         createConversationMessage: builder.mutation<
             ConversationMessage,
             { conversationId: string } & Partial &
@@ -80,6 +89,7 @@ export const {
     useGetAssistantConversationsQuery,
     useGetConversationQuery,
     useGetConversationMessagesQuery,
+    useGetConversationMessageDebugDataQuery,
     useCreateConversationMessageMutation,
     useDeleteConversationMessageMutation,
 } = conversationApi;
@@ -118,24 +128,20 @@ const transformResponseToConversationMessages = (response: any): ConversationMes
 
 const transformResponseToMessage = (response: any): ConversationMessage => {
     try {
-        return {
-            id: response.id,
-            sender: {
-                participantId: response.sender.participant_id,
-                participantRole: response.sender.participant_role,
-            },
-            timestamp: response.timestamp,
-            messageType: response.message_type ?? 'chat',
-            content: response.content,
-            contentType: response.content_type,
-            filenames: response.filenames,
-            metadata: response.metadata,
-        };
+        return conversationMessageFromJSON(response);
     } catch (error) {
         throw new Error(`Failed to transform message response: ${error}`);
     }
 };
 
+const transformResponseToConversationMessageDebug = (response: any): ConversationMessageDebug => {
+    try {
+        return conversationMessageDebugFromJSON(response);
+    } catch (error) {
+        throw new Error(`Failed to transform message debug response: ${error}`);
+    }
+};
+
 const transformMessageForRequest = (message: Partial) => {
     const request: Record = {
         timestamp: message.timestamp,
diff --git a/workbench-service/Makefile b/workbench-service/Makefile
index 15db9729..54f22f67 100644
--- a/workbench-service/Makefile
+++ b/workbench-service/Makefile
@@ -6,7 +6,7 @@ WORKBENCH__DB__URL ?= postgresql:///workbench
 
 .PHONY: alembic-upgrade-head
 alembic-upgrade-head:
-	WORKBENCH__DB__URL="$(WORKBENCH__DB__URL)" alembic upgrade head
+	WORKBENCH__DB__URL="$(WORKBENCH__DB__URL)" uv run alembic upgrade head
 
 .PHONY: alembic-generate-migration
 alembic-generate-migration:
@@ -15,7 +15,7 @@ ifndef migration
 	$(info ex: make alembic-generate-migration migration="neato changes")
 	$(error "migration" is not set)
 else
-	WORKBENCH__DB__URL="$(WORKBENCH__DB__URL)" alembic revision --autogenerate -m "$(migration)"
+	WORKBENCH__DB__URL="$(WORKBENCH__DB__URL)" uv run alembic revision --autogenerate -m "$(migration)"
 endif
 
 DOCKER_PATH = $(repo_root)
diff --git a/workbench-service/migrations/versions/2024_11_04_204029_5149c7fb5a32_conversationmessagedebug.py b/workbench-service/migrations/versions/2024_11_04_204029_5149c7fb5a32_conversationmessagedebug.py
new file mode 100644
index 00000000..483533d1
--- /dev/null
+++ b/workbench-service/migrations/versions/2024_11_04_204029_5149c7fb5a32_conversationmessagedebug.py
@@ -0,0 +1,99 @@
+"""conversationmessagedebug
+
+Revision ID: 5149c7fb5a32
+Revises: 039bec8edc33
+Create Date: 2024-11-04 20:40:29.252951
+
+"""
+
+from typing import Sequence, Union
+
+import sqlalchemy as sa
+import sqlmodel as sm
+from alembic import op
+from semantic_workbench_service import db
+
+# revision identifiers, used by Alembic.
+revision: str = "5149c7fb5a32"
+down_revision: Union[str, None] = "039bec8edc33"
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+    op.create_table(
+        "conversationmessagedebug",
+        sa.Column("message_id", sa.Uuid(), nullable=False),
+        sa.Column("data", sa.JSON(), nullable=False),
+        sa.ForeignKeyConstraint(
+            ["message_id"],
+            ["conversationmessage.message_id"],
+            name="fk_conversationmessagedebug_message_id_conversationmessage",
+            ondelete="CASCADE",
+        ),
+        sa.PrimaryKeyConstraint("message_id"),
+    )
+
+    bind = op.get_bind()
+    max_sequence = bind.execute(sm.select(sm.func.max(db.ConversationMessage.sequence))).scalar()
+    if max_sequence is not None:
+        step = 100
+        for sequence_start in range(1, max_sequence + 1, step):
+            sequence_end_exclusive = sequence_start + step
+
+            results = bind.execute(
+                sm.select(db.ConversationMessage.message_id, db.ConversationMessage.meta_data).where(
+                    db.ConversationMessage.sequence >= sequence_start,
+                    db.ConversationMessage.sequence < sequence_end_exclusive,
+                )
+            ).fetchall()
+
+            for message_id, meta_data in results:
+                debug = meta_data.pop("debug", None)
+                if not debug:
+                    continue
+
+                bind.execute(
+                    sm.insert(db.ConversationMessageDebug).values(
+                        message_id=message_id,
+                        data=debug,
+                    )
+                )
+
+                bind.execute(
+                    sm.update(db.ConversationMessage)
+                    .where(db.ConversationMessage.message_id == message_id)
+                    .values(meta_data=meta_data)
+                )
+
+
+def downgrade() -> None:
+    bind = op.get_bind()
+
+    max_sequence = bind.execute(sm.select(sm.func.max(db.ConversationMessage.sequence))).scalar()
+    if max_sequence is not None:
+        step = 100
+        for sequence_start in range(1, max_sequence + 1, step):
+            sequence_end_exclusive = sequence_start + step
+            results = bind.execute(
+                sm.select(
+                    db.ConversationMessageDebug.message_id,
+                    db.ConversationMessageDebug.data,
+                    db.ConversationMessage.meta_data,
+                )
+                .join(db.ConversationMessage)
+                .where(
+                    db.ConversationMessage.sequence >= sequence_start,
+                    db.ConversationMessage.sequence < sequence_end_exclusive,
+                )
+            ).fetchall()
+
+            for message_id, debug_data, meta_data in results:
+                meta_data["debug"] = debug_data
+                bind.execute(
+                    sm.update(db.ConversationMessage)
+                    .where(db.ConversationMessage.message_id == message_id)
+                    .values(meta_data=meta_data)
+                )
+
+    op.drop_table("conversationmessagedebug")
diff --git a/workbench-service/semantic_workbench_service/controller/conversation.py b/workbench-service/semantic_workbench_service/controller/conversation.py
index e81bde11..32e7954c 100644
--- a/workbench-service/semantic_workbench_service/controller/conversation.py
+++ b/workbench-service/semantic_workbench_service/controller/conversation.py
@@ -4,6 +4,7 @@
 from dataclasses import dataclass
 from typing import AsyncContextManager, Awaitable, Callable, Iterable, Literal, Sequence
 
+import deepmerge
 from semantic_workbench_api_model.assistant_service_client import AssistantError
 from semantic_workbench_api_model.workbench_model import (
     Conversation,
@@ -11,6 +12,7 @@
     ConversationEventType,
     ConversationList,
     ConversationMessage,
+    ConversationMessageDebug,
     ConversationMessageList,
     ConversationParticipant,
     ConversationParticipantList,
@@ -90,7 +92,7 @@ async def create_conversation(
     async def _projections_with_participants(
         self,
         session: AsyncSession,
-        conversation_projections: Sequence[tuple[db.Conversation, db.ConversationMessage | None, str]],
+        conversation_projections: Sequence[tuple[db.Conversation, db.ConversationMessage | None, bool, str]],
     ) -> Iterable[
         tuple[
             db.Conversation,
@@ -98,6 +100,7 @@ async def _projections_with_participants(
             Iterable[db.AssistantParticipant],
             dict[uuid.UUID, db.Assistant],
             db.ConversationMessage | None,
+            bool,
             str,
         ]
     ]:
@@ -137,10 +140,11 @@ def merge() -> Iterable[
                 Iterable[db.AssistantParticipant],
                 dict[uuid.UUID, db.Assistant],
                 db.ConversationMessage | None,
+                bool,
                 str,
             ]
         ]:
-            for conversation, latest_message, permission in conversation_projections:
+            for conversation, latest_message, latest_message_has_debug, permission in conversation_projections:
                 conversation_id = conversation.conversation_id
                 conversation_user_participants = (
                     up for up in user_participants if up.conversation_id == conversation_id
@@ -154,6 +158,7 @@ def merge() -> Iterable[
                     conversation_assistant_participants,
                     assistants_map,
                     latest_message,
+                    latest_message_has_debug,
                     permission,
                 )
 
@@ -246,13 +251,20 @@ async def get_conversation(
             conversation_projections=[conversation_projection],
         )
 
-        conversation, user_participants, assistant_participants, assistants, latest_message, permission = next(
-            iter(projections_with_participants)
-        )
+        (
+            conversation,
+            user_participants,
+            assistant_participants,
+            assistants,
+            latest_message,
+            latest_message_has_debug,
+            permission,
+        ) = next(iter(projections_with_participants))
 
         return convert.conversation_from_db(
             model=conversation,
             latest_message=latest_message,
+            latest_message_has_debug=latest_message_has_debug,
             permission=permission,
             user_participants=user_participants,
             assistant_participants=assistant_participants,
@@ -658,6 +670,11 @@ async def create_conversation_message(
             role = "assistant"
             participant_id = str(principal.assistant_id)
 
+        # pop "debug" from metadata, if it exists, and merge with the debug field
+        message_debug = deepmerge.always_merger.merge(
+            (new_message.metadata or {}).pop("debug", None), new_message.debug_data or {}
+        )
+
         message = db.ConversationMessage(
             conversation_id=conversation.conversation_id,
             sender_participant_role=role,
@@ -672,10 +689,18 @@ async def create_conversation_message(
             message.message_id = new_message.id
 
         session.add(message)
+
+        if message_debug:
+            debug = db.ConversationMessageDebug(
+                message_id=message.message_id,
+                data=message_debug,
+            )
+            session.add(debug)
+
         await session.commit()
         await session.refresh(message)
 
-        message_response = convert.conversation_message_from_db(message)
+        message_response = convert.conversation_message_from_db(message, has_debug=bool(message_debug))
 
         # share message with previewers
         message_preview = MessagePreview(conversation_id=conversation_id, message=message_response)
@@ -700,17 +725,36 @@ async def get_message(
         self, principal: auth.ActorPrincipal, conversation_id: uuid.UUID, message_id: uuid.UUID
     ) -> ConversationMessage:
         async with self._get_session() as session:
-            message = (
+            projection = (
                 await session.exec(
-                    query.select_conversation_messages_for(principal=principal)
+                    query.select_conversation_message_projections_for(principal=principal)
                     .where(db.ConversationMessage.conversation_id == conversation_id)
                     .where(db.ConversationMessage.message_id == message_id)
                 )
             ).one_or_none()
-            if message is None:
+            if projection is None:
+                raise exceptions.NotFoundError()
+
+            message, has_debug = projection
+
+            return convert.conversation_message_from_db(message, has_debug=has_debug)
+
+    async def get_message_debug(
+        self, principal: auth.ActorPrincipal, conversation_id: uuid.UUID, message_id: uuid.UUID
+    ) -> ConversationMessageDebug:
+        async with self._get_session() as session:
+            message_debug = (
+                await session.exec(
+                    query.select_conversation_message_debugs_for(principal=principal).where(
+                        db.Conversation.conversation_id == conversation_id,
+                        db.ConversationMessageDebug.message_id == message_id,
+                    )
+                )
+            ).one_or_none()
+            if message_debug is None:
                 raise exceptions.NotFoundError()
 
-            return convert.conversation_message_from_db(message)
+            return convert.conversation_message_debug_from_db(message_debug)
 
     async def get_messages(
         self,
@@ -734,7 +778,7 @@ async def get_messages(
         if conversation is None:
             raise exceptions.NotFoundError()
 
-        select_query = query.select_conversation_messages_for(principal=principal).where(
+        select_query = query.select_conversation_message_projections_for(principal=principal).where(
             db.ConversationMessage.conversation_id == conversation_id
         )
 
@@ -810,7 +854,7 @@ async def delete_message(
         if message is None:
             raise exceptions.NotFoundError()
 
-        message_response = convert.conversation_message_from_db(message)
+        message_response = convert.conversation_message_from_db(message, has_debug=False)
 
         await session.delete(message)
         await session.commit()
diff --git a/workbench-service/semantic_workbench_service/controller/convert.py b/workbench-service/semantic_workbench_service/controller/convert.py
index b717a3ba..f73093b5 100644
--- a/workbench-service/semantic_workbench_service/controller/convert.py
+++ b/workbench-service/semantic_workbench_service/controller/convert.py
@@ -9,6 +9,7 @@
     Conversation,
     ConversationList,
     ConversationMessage,
+    ConversationMessageDebug,
     ConversationMessageList,
     ConversationParticipant,
     ConversationParticipantList,
@@ -151,6 +152,7 @@ def conversation_from_db(
     assistant_participants: Iterable[db.AssistantParticipant],
     assistants: Mapping[uuid.UUID, db.Assistant],
     latest_message: db.ConversationMessage | None,
+    latest_message_has_debug: bool,
     permission: str,
 ) -> Conversation:
     return Conversation(
@@ -161,7 +163,9 @@
         metadata=model.meta_data,
         created_datetime=model.created_datetime,
         conversation_permission=ConversationPermission(permission),
-        latest_message=conversation_message_from_db(model=latest_message) if latest_message else None,
+        latest_message=conversation_message_from_db(model=latest_message, has_debug=latest_message_has_debug)
+        if latest_message
+        else None,
         participants=conversation_participant_list_from_db(
             user_participants=user_participants,
             assistant_participants=assistant_participants,
@@ -178,6 +182,7 @@ def conversation_list_from_db(
             Iterable[db.AssistantParticipant],
             dict[uuid.UUID, db.Assistant],
             db.ConversationMessage | None,
+            bool,
             str,
         ]
     ],
@@ -190,9 +195,10 @@
                 assistant_participants=assistant_participants,
                 assistants=assistants,
                 latest_message=latest_message,
+                latest_message_has_debug=latest_message_has_debug,
                 permission=permission,
             )
-            for conversation, user_participants, assistant_participants, assistants, latest_message, permission in models
+            for conversation, user_participants, assistant_participants, assistants, latest_message, latest_message_has_debug, permission in models
         ]
     )
 
@@ -236,7 +242,7 @@ def conversation_share_redemption_list_from_db(
     )
 
 
-def conversation_message_from_db(model: db.ConversationMessage) -> ConversationMessage:
+def conversation_message_from_db(model: db.ConversationMessage, has_debug: bool) -> ConversationMessage:
     return ConversationMessage(
         id=model.message_id,
         sender=MessageSender(
@@ -249,13 +255,21 @@ def conversation_message_from_db(model: db.ConversationMessage) -> ConversationM
         content_type=model.content_type,
         metadata=model.meta_data,
         filenames=model.filenames,
+        has_debug_data=has_debug,
     )
 
 
 def conversation_message_list_from_db(
-    models: Iterable[db.ConversationMessage],
+    models: Iterable[tuple[db.ConversationMessage, bool]],
 ) -> ConversationMessageList:
-    return ConversationMessageList(messages=[conversation_message_from_db(m) for m in models])
+    return ConversationMessageList(messages=[conversation_message_from_db(m, debug) for m, debug in models])
+
+
+def conversation_message_debug_from_db(model: db.ConversationMessageDebug) -> ConversationMessageDebug:
+    return ConversationMessageDebug(
+        message_id=model.message_id,
+        debug_data=model.data,
+    )
 
 
 def file_from_db(models: tuple[db.File, db.FileVersion]) -> File:
diff --git a/workbench-service/semantic_workbench_service/controller/export_import.py b/workbench-service/semantic_workbench_service/controller/export_import.py
index ae177461..9e999ecb 100644
--- a/workbench-service/semantic_workbench_service/controller/export_import.py
+++ b/workbench-service/semantic_workbench_service/controller/export_import.py
@@ -70,6 +70,14 @@ async def export_file(
             .order_by(col(db.ConversationMessage.sequence).asc())
         )
 
+        message_debugs = await session.exec(
+            select(db.ConversationMessageDebug)
+            .join(db.ConversationMessage)
+            .where(col(db.ConversationMessage.conversation_id).in_(conversation_ids))
+            .order_by(col(db.ConversationMessage.conversation_id).asc())
+            .order_by(col(db.ConversationMessage.sequence).asc())
+        )
+
         user_participants = await session.exec(
             select(db.UserParticipant)
             .where(col(db.UserParticipant.conversation_id).in_(conversation_ids))
@@ -93,7 +101,14 @@ def _records(*sources: ScalarResult) -> Generator[_Record, None, None]:
         f.writelines(
             _lines_from(
                 _records(
-                    assistants, conversations, messages, user_participants, assistant_participants, files, file_versions
+                    assistants,
+                    conversations,
+                    messages,
+                    message_debugs,
+                    user_participants,
+                    assistant_participants,
+                    files,
+                    file_versions,
                 )
             )
         )
@@ -106,6 +121,7 @@ def _records(*sources: ScalarResult) -> Generator[_Record, None, None]:
 class ImportResult:
     assistant_id_old_to_new: dict[uuid.UUID, uuid.UUID]
     conversation_id_old_to_new: dict[uuid.UUID, uuid.UUID]
+    message_id_old_to_new: dict[uuid.UUID, uuid.UUID]
     assistant_conversation_old_ids: dict[uuid.UUID, set[uuid.UUID]]
     file_id_old_to_new: dict[uuid.UUID, uuid.UUID]
 
@@ -114,6 +130,7 @@ async def import_files(session: AsyncSession, owner_id: str, files: Iterable[IO[
     result = ImportResult(
         assistant_id_old_to_new={},
         conversation_id_old_to_new={},
+        message_id_old_to_new={},
         assistant_conversation_old_ids=collections.defaultdict(set),
         file_id_old_to_new={},
     )
@@ -223,7 +240,8 @@ async def _process_record(record: _Record) -> None:
                 if conversation_id is None:
                     raise RuntimeError(f"conversation_id {message.conversation_id} is not found")
                 message.conversation_id = conversation_id
-                message.message_id = uuid.uuid4()
+                result.message_id_old_to_new[message.message_id] = uuid.uuid4()
+                message.message_id = result.message_id_old_to_new[message.message_id]
 
                 if message.sender_participant_role == "assistant":
                     assistant_id = result.assistant_id_old_to_new.get(uuid.UUID(message.sender_participant_id))
@@ -231,6 +249,14 @@ async def _process_record(record: _Record) -> None:
                     message.sender_participant_id = str(assistant_id)
 
                 session.add(message)
+            case db.ConversationMessageDebug.__name__:
+                message_debug = db.ConversationMessageDebug.model_validate(record.data)
+                message_id = result.message_id_old_to_new.get(message_debug.message_id)
+                if message_id is None:
+                    raise RuntimeError(f"message_id {message_debug.message_id} is not found")
+                message_debug.message_id = message_id
+                session.add(message_debug)
+
             case db.File.__name__:
                 file = db.File.model_validate(record.data)
                 result.file_id_old_to_new[file.file_id] = uuid.uuid4()
diff --git a/workbench-service/semantic_workbench_service/db.py b/workbench-service/semantic_workbench_service/db.py
index 85b911a6..e13b2ddc 100644
--- a/workbench-service/semantic_workbench_service/db.py
+++ b/workbench-service/semantic_workbench_service/db.py
@@ -304,6 +304,24 @@ class ConversationMessage(SQLModel, table=True):
     related_conversation: Conversation = Relationship()
 
 
+class ConversationMessageDebug(SQLModel, table=True):
+    message_id: uuid.UUID = Field(
+        sa_column=sqlalchemy.Column(
+            sqlalchemy.ForeignKey(
+                "conversationmessage.message_id",
+                name="fk_conversationmessagedebug_message_id_conversationmessage",
+                ondelete="CASCADE",
+            ),
+            nullable=False,
+            primary_key=True,
+        ),
+    )
+    data: dict[str, Any] = Field(sa_column=sqlalchemy.Column(sqlalchemy.JSON, nullable=False), default={})
+
+    # this relationship is needed to enforce correct INSERT order by SQLModel
+    related_message: ConversationMessage = Relationship()
+
+
 class File(SQLModel, table=True):
     file_id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
     conversation_id: uuid.UUID = Field(
diff --git a/workbench-service/semantic_workbench_service/query.py b/workbench-service/semantic_workbench_service/query.py
index c15a4bb3..7c8e01a7 100644
--- a/workbench-service/semantic_workbench_service/query.py
+++ b/workbench-service/semantic_workbench_service/query.py
@@ -1,9 +1,9 @@
 from typing import Any, TypeVar
 
+from semantic_workbench_api_model.workbench_model import MessageType
 from sqlalchemy import Function
 from sqlmodel import String, and_, cast, col, func, literal, or_, select
 from sqlmodel.sql.expression import Select, SelectOfScalar
 
-from semantic_workbench_api_model.workbench_model import MessageType
 
 from . import auth, db, settings
@@ -122,12 +122,13 @@ def select_conversation_projections_for(
     latest_message_types: set[MessageType],
     include_all_owned: bool = False,
     include_observer: bool = False,
-) -> Select[tuple[db.Conversation, db.ConversationMessage | None, str]]:
+) -> Select[tuple[db.Conversation, db.ConversationMessage | None, bool, str]]:
     match principal:
         case auth.UserPrincipal():
             select_query = select(
                 db.Conversation,
                 db.ConversationMessage,
+                col(db.ConversationMessageDebug.message_id).is_not(None).label("has_debug"),
                 db.UserParticipant.conversation_permission,
             )
 
@@ -135,6 +136,7 @@
             select_query = select(
                 db.Conversation,
                 db.ConversationMessage,
+                col(db.ConversationMessageDebug.message_id).is_not(None).label("has_debug"),
                 literal("read_write").label("conversation_permission"),
             )
 
@@ -155,27 +157,73 @@
         .subquery()
     )
 
-    return query.join_from(
-        db.Conversation,
-        latest_message_subquery,
-        onclause=col(db.Conversation.conversation_id) == col(latest_message_subquery.c.conversation_id),
-        isouter=True,
-    ).join_from(
-        db.Conversation,
-        db.ConversationMessage,
-        onclause=and_(
-            col(db.Conversation.conversation_id) == col(db.ConversationMessage.conversation_id),
-            col(db.ConversationMessage.sequence) == col(latest_message_subquery.c.latest_message_sequence),
-        ),
-        isouter=True,
+    return (
+        query.join_from(
+            db.Conversation,
+            latest_message_subquery,
+            onclause=col(db.Conversation.conversation_id) == col(latest_message_subquery.c.conversation_id),
+            isouter=True,
+        )
+        .join_from(
+            db.Conversation,
+            db.ConversationMessage,
+            onclause=and_(
+                col(db.Conversation.conversation_id) == col(db.ConversationMessage.conversation_id),
+                col(db.ConversationMessage.sequence) == col(latest_message_subquery.c.latest_message_sequence),
+            ),
+            isouter=True,
+        )
+        .join_from(
+            db.ConversationMessage,
+            db.ConversationMessageDebug,
+            isouter=True,
+        )
     )
 
 
+def _select_conversation_messages_for(
+    select_query: SelectT,
+    principal: auth.ActorPrincipal,
+) -> SelectT:
+    match principal:
+        case auth.UserPrincipal():
+            return (
+                select_query.join(db.Conversation)
+                .join(db.UserParticipant)
+                .where(db.UserParticipant.user_id == principal.user_id)
+            )
+
+        case auth.AssistantPrincipal():
+            return (
+                select_query.join(db.Conversation)
+                .join(db.AssistantParticipant)
+                .where(db.AssistantParticipant.assistant_id == principal.assistant_id)
+            )
+
+
 def select_conversation_messages_for(principal: auth.ActorPrincipal) -> SelectOfScalar[db.ConversationMessage]:
+    return _select_conversation_messages_for(select(db.ConversationMessage), principal)
+
+
+def select_conversation_message_projections_for(
+    principal: auth.ActorPrincipal,
+) -> Select[tuple[db.ConversationMessage, bool]]:
+    return _select_conversation_messages_for(
+        select(db.ConversationMessage, col(db.ConversationMessageDebug.message_id).is_not(None)).join(
+            db.ConversationMessageDebug, isouter=True
+        ),
+        principal,
+    )
+
+
+def select_conversation_message_debugs_for(
+    principal: auth.ActorPrincipal,
+) -> SelectOfScalar[db.ConversationMessageDebug]:
     match principal:
         case auth.UserPrincipal():
             return (
-                select(db.ConversationMessage)
+                select(db.ConversationMessageDebug)
+                .join(db.ConversationMessage)
                 .join(db.Conversation)
                .join(db.UserParticipant)
                .where(db.UserParticipant.user_id == principal.user_id)
@@ -183,7 +231,8 @@
 
         case auth.AssistantPrincipal():
             return (
-                select(db.ConversationMessage)
+                select(db.ConversationMessageDebug)
+                .join(db.ConversationMessage)
                 .join(db.Conversation)
                 .join(db.AssistantParticipant)
                 .where(db.AssistantParticipant.assistant_id == principal.assistant_id)
diff --git a/workbench-service/semantic_workbench_service/service.py b/workbench-service/semantic_workbench_service/service.py
index fcdc0aa9..34472fa4 100644
--- a/workbench-service/semantic_workbench_service/service.py
+++ b/workbench-service/semantic_workbench_service/service.py
@@ -52,6 +52,7 @@
     ConversationImportResult,
     ConversationList,
     ConversationMessage,
+    ConversationMessageDebug,
     ConversationMessageList,
     ConversationParticipant,
     ConversationParticipantList,
@@ -896,6 +897,20 @@ async def get_message(
             principal=principal,
         )
 
+    @app.get(
+        "/conversations/{conversation_id}/messages/{message_id}/debug_data",
+    )
+    async def get_message_debug_data(
+        conversation_id: uuid.UUID,
+        message_id: uuid.UUID,
+        principal: auth.DependsActorPrincipal,
+    ) -> ConversationMessageDebug:
+        return await conversation_controller.get_message_debug(
+            conversation_id=conversation_id,
+            message_id=message_id,
+            principal=principal,
+        )
+
     @app.delete(
         "/conversations/{conversation_id}/messages/{message_id}",
         status_code=status.HTTP_204_NO_CONTENT,
diff --git a/workbench-service/tests/test_workbench_service.py b/workbench-service/tests/test_workbench_service.py
index 70c33bb9..b520d8ac 100644
--- a/workbench-service/tests/test_workbench_service.py
+++ b/workbench-service/tests/test_workbench_service.py
@@ -640,6 +640,7 @@ def test_create_conversation_send_user_message(workbench_service: FastAPI, test_
     assert httpx.codes.is_success(http_response.status_code)
     message = workbench_model.ConversationMessage.model_validate(http_response.json())
     message_id = message.id
+    assert message.has_debug_data is False
 
     http_response = client.get(f"/conversations/{conversation_id}/messages")
     assert httpx.codes.is_success(http_response.status_code)
@@ -655,14 +656,28 @@ def test_create_conversation_send_user_message(workbench_service: FastAPI, test_
     assert message.content == "hello"
     assert message.sender.participant_id == test_user.id
 
-    # send chat another message
-    payload = {"content": "hello again"}
+    # send another chat message, with debug
+    payload = {
+        "content": "hello again",
+        "metadata": {"debug": {"key1": "value1"}},
+        "debug_data": {"key2": "value2"},
+    }
     http_response = client.post(f"/conversations/{conversation_id}/messages", json=payload)
     logging.info("response: %s", http_response.json())
     assert httpx.codes.is_success(http_response.status_code)
     message = workbench_model.ConversationMessage.model_validate(http_response.json())
     message_two_id = message.id
 
+    # debug should be stripped out
+    assert message.metadata == {}
+    assert message.has_debug_data is True
+
+    http_response = client.get(f"/conversations/{conversation_id}/messages/{message_two_id}/debug_data")
+    assert httpx.codes.is_success(http_response.status_code)
+    message = workbench_model.ConversationMessageDebug.model_validate(http_response.json())
+
+    assert message.debug_data == {"key1": "value1", "key2": "value2"}
+
     # send a log message
     payload = {"content": "hello again", "message_type": "log"}
     http_response = client.post(f"/conversations/{conversation_id}/messages", json=payload)
@@ -1112,7 +1127,7 @@ def test_create_assistant_conversations_export_import_conversations(
     http_response = client.put(f"/conversations/{conversation_id_2}/participants/{assistant_id_1}", json={})
     assert httpx.codes.is_success(http_response.status_code)
 
-    payload = {"content": "hello"}
+    payload = {"content": "hello", "debug_data": {"key": "value"}}
     http_response = client.post(f"/conversations/{conversation_id_1}/messages", json=payload)
     assert httpx.codes.is_success(http_response.status_code)
 
@@ -1172,6 +1187,23 @@
     assert conversations.conversations[4].title == "test-conversation-2 (1)"
     assert conversations.conversations[5].title == "test-conversation-2 (2)"
 
+    for conversation in conversations.conversations:
+        http_response = client.get(f"/conversations/{conversation.id}/messages")
+        assert httpx.codes.is_success(http_response.status_code)
+
+        messages = workbench_model.ConversationMessageList.model_validate(http_response.json())
+        assert len(messages.messages) == 1
+
+        message = messages.messages[0]
+        assert message.content == "hello"
+        assert message.sender.participant_id == test_user.id
+        assert message.has_debug_data is True
+
+        http_response = client.get(f"/conversations/{conversation.id}/messages/{message.id}/debug_data")
+        assert httpx.codes.is_success(http_response.status_code)
+        message_debug = workbench_model.ConversationMessageDebug.model_validate(http_response.json())
+        assert message_debug.debug_data == {"key": "value"}
+
 
 def test_export_import_conversations_with_files(
     workbench_service: FastAPI,