diff --git a/apps/nextjs/src/app/api/chat/chatHandler.ts b/apps/nextjs/src/app/api/chat/chatHandler.ts index b998f6e3e..5b85a3d5f 100644 --- a/apps/nextjs/src/app/api/chat/chatHandler.ts +++ b/apps/nextjs/src/app/api/chat/chatHandler.ts @@ -12,6 +12,7 @@ import { PosthogAnalyticsAdapter, } from "@oakai/aila/src/features/analytics"; import { AilaRag } from "@oakai/aila/src/features/rag/AilaRag"; +import type { AilaThreatDetector } from "@oakai/aila/src/features/threatDetection"; import { HeliconeThreatDetector } from "@oakai/aila/src/features/threatDetection/detectors/helicone/HeliconeThreatDetector"; import { LakeraThreatDetector } from "@oakai/aila/src/features/threatDetection/detectors/lakera/LakeraThreatDetector"; import type { LooseLessonPlan } from "@oakai/aila/src/protocol/schema"; @@ -20,6 +21,7 @@ import { withTelemetry } from "@oakai/core/src/tracing/serverTracing"; import type { PrismaClientWithAccelerate } from "@oakai/db"; import { prisma as globalPrisma } from "@oakai/db/client"; import { aiLogger } from "@oakai/logger"; +import { captureException } from "@sentry/nextjs"; import type { NextRequest } from "next/server"; import invariant from "tiny-invariant"; @@ -178,30 +180,63 @@ function isValidLessonPlan(lessonPlan: unknown): boolean { return lessonPlan !== null && typeof lessonPlan === "object"; } -function parseLessonPlanFromOutput(output: unknown): LooseLessonPlan { - if (!output) return {}; +function hasMessages(obj: unknown): obj is { messages: unknown } { + return obj !== null && typeof obj === "object" && "messages" in obj; +} + +function isValidMessages(messages: unknown): boolean { + return Array.isArray(messages); +} + +function verifyChatOwnership( + chat: { userId: string }, + requestUserId: string, + chatId: string, +): void { + if (chat.userId !== requestUserId) { + log.error( + `User ${requestUserId} attempted to access chat ${chatId} which belongs to ${chat.userId}`, + ); + throw new Error("Unauthorized access to chat"); + } +} + +function parseChatOutput( + output: unknown, + chatId: string, +): { messages: Message[]; lessonPlan: LooseLessonPlan } { + let messages: Message[] = []; + let lessonPlan: LooseLessonPlan = {}; try { const parsedOutput = typeof output === "string" ? JSON.parse(output) : output; + if (hasMessages(parsedOutput) && isValidMessages(parsedOutput.messages)) { + messages = parsedOutput.messages as Message[]; + } + if ( hasLessonPlan(parsedOutput) && isValidLessonPlan(parsedOutput.lessonPlan) ) { - return parsedOutput.lessonPlan as LooseLessonPlan; + lessonPlan = parsedOutput.lessonPlan as LooseLessonPlan; } } catch (error) { - log.error("Error parsing output to extract lesson plan", error); + log.error(`Error parsing output for chat ${chatId}`, error); + captureException(error, { + extra: { chatId, output }, + tags: { context: "parseChatOutput" }, + }); } - return {}; + return { messages, lessonPlan }; } -async function loadLessonPlanFromDatabase( +async function loadChatDataFromDatabase( chatId: string, userId: string, -): Promise { +): Promise<{ messages: Message[]; lessonPlan: LooseLessonPlan }> { try { const chat = await prisma.appSession.findUnique({ where: { id: chatId }, @@ -214,25 +249,105 @@ async function loadLessonPlanFromDatabase( if (!chat) { log.info(`No existing chat found for id: ${chatId}`); - return {}; + return { messages: [], lessonPlan: {} }; } - if (chat.userId !== userId) { - log.error( - `User ${userId} attempted to access chat ${chatId} which belongs to ${chat.userId}`, - ); - throw new Error("Unauthorized access to chat"); - } + verifyChatOwnership(chat, userId, chatId); + + const { messages, lessonPlan } = parseChatOutput(chat.output, chatId); - const lessonPlan = parseLessonPlanFromOutput(chat.output); - log.info(`Loaded lesson plan for chat ${chatId}`); - return lessonPlan; + log.info( + `Loaded ${messages.length} messages and lesson plan for chat ${chatId}`, + ); + return { messages, lessonPlan }; } catch (error) { - log.error(`Error loading lesson plan for chat ${chatId}`, error); - return {}; + log.error(`Error loading chat data for chat ${chatId}`, error); + captureException(error, { + extra: { chatId, userId }, + tags: { context: "loadChatDataFromDatabase" }, + }); + throw error; } } +function extractLatestUserMessage(frontendMessages: Message[]): Message | null { + return (frontendMessages ?? []).findLast((m) => m?.role === "user") ?? null; +} + +function prepareMessages( + dbMessages: Message[], + frontendMessages: Message[], + chatId: string, +): Message[] { + const latestUserMessage = extractLatestUserMessage(frontendMessages); + + let messages = [...dbMessages]; + if ( + latestUserMessage && + !messages.some((m) => m.id === latestUserMessage.id) + ) { + messages.push(latestUserMessage); + log.info(`Appended new user message to history for chat ${chatId}`); + } + + return messages; +} + +type CreateAilaInstanceArguments = { + config: Config; + options: AilaOptions; + chatId: string; + userId: string | undefined; + messages: Message[]; + lessonPlan: LooseLessonPlan; + llmService: ReturnType; + moderationAiClient: ReturnType; + threatDetectors: AilaThreatDetector[]; +}; + +async function createAilaInstance({ + config, + options, + chatId, + userId, + messages, + lessonPlan, + llmService, + moderationAiClient, + threatDetectors, +}: CreateAilaInstanceArguments): Promise { + return await withTelemetry( + "chat-create-aila", + { chat_id: chatId, user_id: userId }, + async (): Promise => { + const ailaOptions: Partial = { + options, + chat: { + id: chatId, + userId, + messages, + }, + services: { + chatLlmService: llmService, + moderationAiClient, + ragService: (aila: AilaServices) => new AilaRag({ aila }), + americanismsService: () => new AilaAmericanisms(), + analyticsAdapters: (aila: AilaServices) => [ + new PosthogAnalyticsAdapter(aila), + new DatadogAnalyticsAdapter(aila), + ], + threatDetectors: () => threatDetectors, + }, + document: { + content: lessonPlan ?? {}, + }, + }; + const result = await config.createAila(ailaOptions); + return result; + }, + ); +} + export async function handleChatPostRequest( req: NextRequest, config: Config, @@ -240,7 +355,7 @@ export async function handleChatPostRequest( return await withTelemetry("chat-api", {}, async (span: TracingSpan) => { const { chatId, - messages, + messages: frontendMessages, options, llmService, moderationAiClient, @@ -254,7 +369,10 @@ export async function handleChatPostRequest( userId = await fetchAndCheckUser(chatId); span.setTag("user_id", userId); - const dbLessonPlan = await loadLessonPlanFromDatabase(chatId, userId); + const { messages: dbMessages, lessonPlan: dbLessonPlan } = + await loadChatDataFromDatabase(chatId, userId); + + const messages = prepareMessages(dbMessages, frontendMessages, chatId); setTelemetryMetadata({ span, @@ -264,38 +382,17 @@ export async function handleChatPostRequest( options, }); - aila = await withTelemetry( - "chat-create-aila", - { chat_id: chatId, user_id: userId }, - async (): Promise => { - const ailaOptions: Partial = { - options, - chat: { - id: chatId, - userId, - messages, - }, - services: { - chatLlmService: llmService, - moderationAiClient, - ragService: (aila: AilaServices) => new AilaRag({ aila }), - americanismsService: () => - new AilaAmericanisms(), - analyticsAdapters: (aila: AilaServices) => [ - new PosthogAnalyticsAdapter(aila), - new DatadogAnalyticsAdapter(aila), - ], - threatDetectors: () => threatDetectors, - }, - - document: { - content: dbLessonPlan ?? {}, - }, - }; - const result = await config.createAila(ailaOptions); - return result; - }, - ); + aila = await createAilaInstance({ + config, + options, + chatId, + userId, + messages, + lessonPlan: dbLessonPlan, + llmService, + moderationAiClient, + threatDetectors, + }); invariant(aila, "Aila instance is required"); const abortController = handleConnectionAborted(req); diff --git a/apps/nextjs/src/components/AppComponents/Chat/AiSdk.tsx b/apps/nextjs/src/components/AppComponents/Chat/AiSdk.tsx index fe7159b3b..e2ef36b7c 100644 --- a/apps/nextjs/src/components/AppComponents/Chat/AiSdk.tsx +++ b/apps/nextjs/src/components/AppComponents/Chat/AiSdk.tsx @@ -45,7 +45,6 @@ export function AiSdk({ id }: Readonly) { const [hasFinished, setHasFinished] = useState(true); const initialMessages = useChatStore((state) => state.initialMessages); - const lessonPlan = useLessonPlanStore((state) => state.lessonPlan); const streamingFinished = useChatStore((state) => state.streamingFinished); const scrollToBottom = useChatStore((state) => state.scrollToBottom); const messageStarted = useLessonPlanStore((state) => state.messageStarted); diff --git a/packages/aila/src/core/document/AilaHomework.ts b/packages/aila/src/core/document/AilaHomework.ts deleted file mode 100644 index e69de29bb..000000000 diff --git a/packages/aila/src/core/document/AilaLessonPlan.ts b/packages/aila/src/core/document/AilaLessonPlan.ts deleted file mode 100644 index e69de29bb..000000000