Skip to content

Commit

Permalink
Includes participants in conversation gets
Browse files Browse the repository at this point in the history
And adds parameter to specify what types of messages you care about when
returning the latest message
  • Loading branch information
markwaddle committed Oct 30, 2024
1 parent 8dd28f4 commit 7cf83de
Show file tree
Hide file tree
Showing 7 changed files with 167 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,10 @@ class Conversation(BaseModel):
imported_from_conversation_id: uuid.UUID | None
metadata: dict[str, Any]
created_datetime: datetime.datetime

conversation_permission: ConversationPermission
latest_message: ConversationMessage | None
participants: list[ConversationParticipant]


class ConversationList(BaseModel):
Expand Down
124 changes: 114 additions & 10 deletions workbench-service/semantic_workbench_service/controller/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
import uuid
from dataclasses import dataclass
from typing import AsyncContextManager, Awaitable, Callable, Literal
from typing import AsyncContextManager, Awaitable, Callable, Iterable, Literal, Sequence

from semantic_workbench_api_model.assistant_service_client import AssistantError
from semantic_workbench_api_model.workbench_model import (
Expand Down Expand Up @@ -83,11 +83,86 @@ async def create_conversation(
await session.commit()
await session.refresh(conversation)

return await self.get_conversation(conversation_id=conversation.conversation_id, principal=user_principal)
return await self.get_conversation(
conversation_id=conversation.conversation_id, principal=user_principal, latest_message_types=set()
)

async def _projections_with_participants(
self,
session: AsyncSession,
conversation_projections: Sequence[tuple[db.Conversation, db.ConversationMessage | None, str]],
) -> Iterable[
tuple[
db.Conversation,
Iterable[db.UserParticipant],
Iterable[db.AssistantParticipant],
dict[uuid.UUID, db.Assistant],
db.ConversationMessage | None,
str,
]
]:
user_participants = (
await session.exec(
select(db.UserParticipant).where(
col(db.UserParticipant.conversation_id).in_([
c[0].conversation_id for c in conversation_projections
])
)
)
).all()

assistant_participants = (
await session.exec(
select(db.AssistantParticipant).where(
col(db.AssistantParticipant.conversation_id).in_([
c[0].conversation_id for c in conversation_projections
])
)
)
).all()

assistants = (
await session.exec(
select(db.Assistant).where(
col(db.Assistant.assistant_id).in_([p.assistant_id for p in assistant_participants])
)
)
).all()
assistants_map = {assistant.assistant_id: assistant for assistant in assistants}

def merge() -> Iterable[
tuple[
db.Conversation,
Iterable[db.UserParticipant],
Iterable[db.AssistantParticipant],
dict[uuid.UUID, db.Assistant],
db.ConversationMessage | None,
str,
]
]:
for conversation, latest_message, permission in conversation_projections:
conversation_id = conversation.conversation_id
conversation_user_participants = (
up for up in user_participants if up.conversation_id == conversation_id
)
conversation_assistant_participants = (
ap for ap in assistant_participants if ap.conversation_id == conversation_id
)
yield (
conversation,
conversation_user_participants,
conversation_assistant_participants,
assistants_map,
latest_message,
permission,
)

return merge()

async def get_conversations(
self,
principal: auth.ActorPrincipal,
latest_message_types: set[MessageType],
include_all_owned: bool = False,
) -> ConversationList:
async with self._get_session() as session:
Expand All @@ -96,17 +171,25 @@ async def get_conversations(
conversation_projections = (
await session.exec(
query.select_conversation_projections_for(
principal=principal, include_all_owned=include_all_owned, include_observer=True
principal=principal,
include_all_owned=include_all_owned,
include_observer=True,
latest_message_types=latest_message_types,
).order_by(col(db.Conversation.created_datetime).desc())
)
).all()

return convert.conversation_list_from_db(models=conversation_projections)
projections_with_participants = await self._projections_with_participants(
session=session, conversation_projections=conversation_projections
)

return convert.conversation_list_from_db(models=projections_with_participants)

async def get_assistant_conversations(
self,
user_principal: auth.UserPrincipal,
assistant_id: uuid.UUID,
latest_message_types: set[MessageType],
) -> ConversationList:
async with self._get_session() as session:
assistant = (
Expand All @@ -119,40 +202,61 @@ async def get_assistant_conversations(
if assistant is None:
raise exceptions.NotFoundError()

conversations = (
conversation_projections = (
await session.exec(
query.select_conversation_projections_for(
principal=auth.AssistantPrincipal(
assistant_service_id=assistant.assistant_service_id, assistant_id=assistant_id
),
latest_message_types=latest_message_types,
)
)
).all()

return convert.conversation_list_from_db(models=conversations)
projections_with_participants = await self._projections_with_participants(
session=session, conversation_projections=conversation_projections
)

return convert.conversation_list_from_db(models=projections_with_participants)

async def get_conversation(
self,
conversation_id: uuid.UUID,
principal: auth.ActorPrincipal,
latest_message_types: set[MessageType],
) -> Conversation:
async with self._get_session() as session:
include_all_owned = isinstance(principal, auth.UserPrincipal)

conversation_projection = (
await session.exec(
query.select_conversation_projections_for(
principal=principal, include_all_owned=include_all_owned, include_observer=True
principal=principal,
include_all_owned=include_all_owned,
include_observer=True,
latest_message_types=latest_message_types,
).where(db.Conversation.conversation_id == conversation_id)
)
).one_or_none()
if conversation_projection is None:
raise exceptions.NotFoundError()

conversation, latest_message, permission = conversation_projection
projections_with_participants = await self._projections_with_participants(
session=session,
conversation_projections=[conversation_projection],
)

conversation, user_participants, assistant_participants, assistants, latest_message, permission = next(
iter(projections_with_participants)
)

return convert.conversation_from_db(
model=conversation, latest_message=latest_message, permission=permission
model=conversation,
latest_message=latest_message,
permission=permission,
user_participants=user_participants,
assistant_participants=assistant_participants,
assistants=assistants,
)

async def update_conversation(
Expand Down Expand Up @@ -190,7 +294,7 @@ async def update_conversation(
await session.refresh(conversation)

conversation_model = await self.get_conversation(
conversation_id=conversation.conversation_id, principal=user_principal
conversation_id=conversation.conversation_id, principal=user_principal, latest_message_types=set()
)

await self._notify_event(
Expand Down
24 changes: 22 additions & 2 deletions workbench-service/semantic_workbench_service/controller/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,9 @@ def conversation_participant_list_from_db(

def conversation_from_db(
model: db.Conversation,
user_participants: Iterable[db.UserParticipant],
assistant_participants: Iterable[db.AssistantParticipant],
assistants: Mapping[uuid.UUID, db.Assistant],
latest_message: db.ConversationMessage | None,
permission: str,
) -> Conversation:
Expand All @@ -159,20 +162,37 @@ def conversation_from_db(
created_datetime=model.created_datetime,
conversation_permission=ConversationPermission(permission),
latest_message=conversation_message_from_db(model=latest_message) if latest_message else None,
participants=conversation_participant_list_from_db(
user_participants=user_participants,
assistant_participants=assistant_participants,
assistants=assistants,
).participants,
)


def conversation_list_from_db(
models: Iterable[tuple[db.Conversation, db.ConversationMessage | None, str]],
models: Iterable[
tuple[
db.Conversation,
Iterable[db.UserParticipant],
Iterable[db.AssistantParticipant],
dict[uuid.UUID, db.Assistant],
db.ConversationMessage | None,
str,
]
],
) -> ConversationList:
return ConversationList(
conversations=[
conversation_from_db(
model=conversation,
user_participants=user_participants,
assistant_participants=assistant_participants,
assistants=assistants,
latest_message=latest_message,
permission=permission,
)
for conversation, latest_message, permission in models
for conversation, user_participants, assistant_participants, assistants, latest_message, permission in models
]
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -661,6 +661,7 @@ async def ensure_configuration_of_conversation_for_workflow_state(
conversation = await self._conversation_controller.get_conversation(
conversation_id=uuid.UUID(conversation_id),
principal=service_user_principals.workflow,
latest_message_types=set(),
)
except Exception as e:
raise exceptions.RuntimeError(
Expand Down Expand Up @@ -1607,6 +1608,7 @@ async def update_conversation_title(
conversation = await self._conversation_controller.get_conversation(
conversation_id=conversation_id,
principal=service_user_principals.workflow,
latest_message_types=set(),
)
except Exception as e:
raise exceptions.RuntimeError(
Expand Down
3 changes: 3 additions & 0 deletions workbench-service/semantic_workbench_service/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from sqlalchemy import Function
from sqlmodel import String, and_, cast, col, func, literal, or_, select
from sqlmodel.sql.expression import Select, SelectOfScalar
from semantic_workbench_api_model.workbench_model import MessageType

from . import auth, db, settings

Expand Down Expand Up @@ -118,6 +119,7 @@ def select_conversations_for(

def select_conversation_projections_for(
principal: auth.ActorPrincipal,
latest_message_types: set[MessageType],
include_all_owned: bool = False,
include_observer: bool = False,
) -> Select[tuple[db.Conversation, db.ConversationMessage | None, str]]:
Expand Down Expand Up @@ -148,6 +150,7 @@ def select_conversation_projections_for(
db.ConversationMessage.conversation_id,
func.max(db.ConversationMessage.sequence).label("latest_message_sequence"),
)
.where(col(db.ConversationMessage.message_type).in_(latest_message_types))
.group_by(col(db.ConversationMessage.conversation_id))
.subquery()
)
Expand Down
18 changes: 16 additions & 2 deletions workbench-service/semantic_workbench_service/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,13 @@ async def _notify_event(queue_item: ConversationEventQueueItem) -> None:
queue_item.event.id,
)

if queue_item.event.event == ConversationEventType.message_created:
if queue_item.event.event in [
ConversationEventType.message_created,
ConversationEventType.message_deleted,
ConversationEventType.conversation_updated,
ConversationEventType.participant_created,
ConversationEventType.participant_updated,
]:
task = asyncio.create_task(
_notify_user_event(queue_item.event.conversation_id), name="notify_user_event"
)
Expand Down Expand Up @@ -604,18 +610,22 @@ async def delete_assistant(
async def get_assistant_conversations(
assistant_id: uuid.UUID,
user_principal: auth.DependsUserPrincipal,
latest_message_types: Annotated[list[MessageType], Query(alias="latest_message_type")] = [MessageType.chat],
) -> ConversationList:
return await conversation_controller.get_assistant_conversations(
user_principal=user_principal,
assistant_id=assistant_id,
latest_message_types=set(latest_message_types),
)

@app.get("/conversations/{conversation_id}/events")
async def conversation_server_sent_events(
conversation_id: uuid.UUID, request: Request, user_principal: auth.DependsUserPrincipal
) -> EventSourceResponse:
# ensure the conversation exists
await conversation_controller.get_conversation(conversation_id=conversation_id, principal=user_principal)
await conversation_controller.get_conversation(
conversation_id=conversation_id, principal=user_principal, latest_message_types=set()
)

logger.debug(
"client connected to sse; user_id: %s, conversation_id: %s", user_principal.user_id, conversation_id
Expand Down Expand Up @@ -753,20 +763,24 @@ async def create_conversation(
async def list_conversations(
principal: auth.DependsActorPrincipal,
include_inactive: bool = False,
latest_message_types: Annotated[list[MessageType], Query(alias="latest_message_type")] = [MessageType.chat],
) -> ConversationList:
return await conversation_controller.get_conversations(
principal=principal,
include_all_owned=include_inactive,
latest_message_types=set(latest_message_types),
)

@app.get("/conversations/{conversation_id}")
async def get_conversation(
conversation_id: uuid.UUID,
principal: auth.DependsActorPrincipal,
latest_message_types: Annotated[list[MessageType], Query(alias="latest_message_type")] = [MessageType.chat],
) -> Conversation:
return await conversation_controller.get_conversation(
principal=principal,
conversation_id=conversation_id,
latest_message_types=set(latest_message_types),
)

@app.patch("/conversations/{conversation_id}")
Expand Down
9 changes: 8 additions & 1 deletion workbench-service/tests/test_workbench_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -734,11 +734,18 @@ def test_create_conversation_send_user_message(workbench_service: FastAPI, test_
messages = workbench_model.ConversationMessageList.model_validate(http_response.json())
assert len(messages.messages) == 3

# check latest message in conversation
# check latest chat message in conversation (chat is default)
http_response = client.get(f"/conversations/{conversation_id}")
assert httpx.codes.is_success(http_response.status_code)
conversation = workbench_model.Conversation.model_validate(http_response.json())
assert conversation.latest_message is not None
assert conversation.latest_message.id == message_two_id

# check latest log message in conversation
http_response = client.get(f"/conversations/{conversation_id}", params={"latest_message_type": ["log"]})
assert httpx.codes.is_success(http_response.status_code)
conversation = workbench_model.Conversation.model_validate(http_response.json())
assert conversation.latest_message is not None
assert conversation.latest_message.id == message_log_id


Expand Down

0 comments on commit 7cf83de

Please sign in to comment.