-
Notifications
You must be signed in to change notification settings - Fork 1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: do not trust front end message history #570
Changes from all commits
8abcd8d
3edac78
5945091
6a0f82a
a5725e2
2c09377
8adbd36
4f21ab0
77caad8
eee4264
9020f33
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,6 +12,7 @@ import { | |
PosthogAnalyticsAdapter, | ||
} from "@oakai/aila/src/features/analytics"; | ||
import { AilaRag } from "@oakai/aila/src/features/rag/AilaRag"; | ||
import type { AilaThreatDetector } from "@oakai/aila/src/features/threatDetection"; | ||
import { HeliconeThreatDetector } from "@oakai/aila/src/features/threatDetection/detectors/helicone/HeliconeThreatDetector"; | ||
import { LakeraThreatDetector } from "@oakai/aila/src/features/threatDetection/detectors/lakera/LakeraThreatDetector"; | ||
import type { LooseLessonPlan } from "@oakai/aila/src/protocol/schema"; | ||
|
@@ -20,6 +21,7 @@ import { withTelemetry } from "@oakai/core/src/tracing/serverTracing"; | |
import type { PrismaClientWithAccelerate } from "@oakai/db"; | ||
import { prisma as globalPrisma } from "@oakai/db/client"; | ||
import { aiLogger } from "@oakai/logger"; | ||
import { captureException } from "@sentry/nextjs"; | ||
import type { NextRequest } from "next/server"; | ||
import invariant from "tiny-invariant"; | ||
|
||
|
@@ -178,30 +180,63 @@ function isValidLessonPlan(lessonPlan: unknown): boolean { | |
return lessonPlan !== null && typeof lessonPlan === "object"; | ||
} | ||
|
||
function parseLessonPlanFromOutput(output: unknown): LooseLessonPlan { | ||
if (!output) return {}; | ||
function hasMessages(obj: unknown): obj is { messages: unknown } { | ||
return obj !== null && typeof obj === "object" && "messages" in obj; | ||
} | ||
|
||
function isValidMessages(messages: unknown): boolean { | ||
return Array.isArray(messages); | ||
} | ||
|
||
function verifyChatOwnership( | ||
chat: { userId: string }, | ||
requestUserId: string, | ||
chatId: string, | ||
): void { | ||
if (chat.userId !== requestUserId) { | ||
log.error( | ||
`User ${requestUserId} attempted to access chat ${chatId} which belongs to ${chat.userId}`, | ||
); | ||
throw new Error("Unauthorized access to chat"); | ||
} | ||
} | ||
|
||
function parseChatOutput( | ||
output: unknown, | ||
chatId: string, | ||
): { messages: Message[]; lessonPlan: LooseLessonPlan } { | ||
let messages: Message[] = []; | ||
let lessonPlan: LooseLessonPlan = {}; | ||
|
||
try { | ||
const parsedOutput = | ||
typeof output === "string" ? JSON.parse(output) : output; | ||
|
||
if (hasMessages(parsedOutput) && isValidMessages(parsedOutput.messages)) { | ||
messages = parsedOutput.messages as Message[]; | ||
} | ||
|
||
if ( | ||
hasLessonPlan(parsedOutput) && | ||
isValidLessonPlan(parsedOutput.lessonPlan) | ||
) { | ||
return parsedOutput.lessonPlan as LooseLessonPlan; | ||
lessonPlan = parsedOutput.lessonPlan as LooseLessonPlan; | ||
} | ||
} catch (error) { | ||
log.error("Error parsing output to extract lesson plan", error); | ||
log.error(`Error parsing output for chat ${chatId}`, error); | ||
captureException(error, { | ||
extra: { chatId, output }, | ||
tags: { context: "parseChatOutput" }, | ||
}); | ||
} | ||
|
||
return {}; | ||
return { messages, lessonPlan }; | ||
} | ||
|
||
async function loadLessonPlanFromDatabase( | ||
async function loadChatDataFromDatabase( | ||
chatId: string, | ||
userId: string, | ||
): Promise<LooseLessonPlan> { | ||
): Promise<{ messages: Message[]; lessonPlan: LooseLessonPlan }> { | ||
try { | ||
const chat = await prisma.appSession.findUnique({ | ||
where: { id: chatId }, | ||
|
@@ -214,33 +249,113 @@ async function loadLessonPlanFromDatabase( | |
|
||
if (!chat) { | ||
log.info(`No existing chat found for id: ${chatId}`); | ||
return {}; | ||
return { messages: [], lessonPlan: {} }; | ||
} | ||
|
||
if (chat.userId !== userId) { | ||
log.error( | ||
`User ${userId} attempted to access chat ${chatId} which belongs to ${chat.userId}`, | ||
); | ||
throw new Error("Unauthorized access to chat"); | ||
} | ||
verifyChatOwnership(chat, userId, chatId); | ||
|
||
const { messages, lessonPlan } = parseChatOutput(chat.output, chatId); | ||
|
||
const lessonPlan = parseLessonPlanFromOutput(chat.output); | ||
log.info(`Loaded lesson plan for chat ${chatId}`); | ||
return lessonPlan; | ||
log.info( | ||
`Loaded ${messages.length} messages and lesson plan for chat ${chatId}`, | ||
); | ||
return { messages, lessonPlan }; | ||
} catch (error) { | ||
log.error(`Error loading lesson plan for chat ${chatId}`, error); | ||
return {}; | ||
log.error(`Error loading chat data for chat ${chatId}`, error); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Again here, just want to confirm that we don't want to report to sentry |
||
captureException(error, { | ||
extra: { chatId, userId }, | ||
tags: { context: "loadChatDataFromDatabase" }, | ||
}); | ||
throw error; | ||
} | ||
} | ||
|
||
function extractLatestUserMessage(frontendMessages: Message[]): Message | null { | ||
return (frontendMessages ?? []).findLast((m) => m?.role === "user") ?? null; | ||
} | ||
|
||
function prepareMessages( | ||
dbMessages: Message[], | ||
frontendMessages: Message[], | ||
chatId: string, | ||
): Message[] { | ||
const latestUserMessage = extractLatestUserMessage(frontendMessages); | ||
|
||
let messages = [...dbMessages]; | ||
if ( | ||
latestUserMessage && | ||
!messages.some((m) => m.id === latestUserMessage.id) | ||
) { | ||
messages.push(latestUserMessage); | ||
log.info(`Appended new user message to history for chat ${chatId}`); | ||
} | ||
|
||
return messages; | ||
} | ||
|
||
type CreateAilaInstanceArguments = { | ||
config: Config; | ||
options: AilaOptions; | ||
chatId: string; | ||
userId: string | undefined; | ||
messages: Message[]; | ||
lessonPlan: LooseLessonPlan; | ||
llmService: ReturnType<typeof getFixtureLLMService>; | ||
moderationAiClient: ReturnType<typeof getFixtureModerationOpenAiClient>; | ||
threatDetectors: AilaThreatDetector[]; | ||
}; | ||
|
||
async function createAilaInstance({ | ||
config, | ||
options, | ||
chatId, | ||
userId, | ||
messages, | ||
lessonPlan, | ||
llmService, | ||
moderationAiClient, | ||
threatDetectors, | ||
}: CreateAilaInstanceArguments): Promise<Aila> { | ||
return await withTelemetry( | ||
"chat-create-aila", | ||
{ chat_id: chatId, user_id: userId }, | ||
async (): Promise<Aila> => { | ||
const ailaOptions: Partial<AilaInitializationOptions> = { | ||
options, | ||
chat: { | ||
id: chatId, | ||
userId, | ||
messages, | ||
}, | ||
services: { | ||
chatLlmService: llmService, | ||
moderationAiClient, | ||
ragService: (aila: AilaServices) => new AilaRag({ aila }), | ||
americanismsService: () => new AilaAmericanisms(), | ||
analyticsAdapters: (aila: AilaServices) => [ | ||
new PosthogAnalyticsAdapter(aila), | ||
new DatadogAnalyticsAdapter(aila), | ||
], | ||
threatDetectors: () => threatDetectors, | ||
}, | ||
document: { | ||
content: lessonPlan ?? {}, | ||
}, | ||
}; | ||
const result = await config.createAila(ailaOptions); | ||
return result; | ||
}, | ||
); | ||
} | ||
|
||
export async function handleChatPostRequest( | ||
req: NextRequest, | ||
config: Config, | ||
): Promise<Response> { | ||
return await withTelemetry("chat-api", {}, async (span: TracingSpan) => { | ||
const { | ||
chatId, | ||
messages, | ||
messages: frontendMessages, | ||
options, | ||
llmService, | ||
moderationAiClient, | ||
|
@@ -254,7 +369,10 @@ export async function handleChatPostRequest( | |
userId = await fetchAndCheckUser(chatId); | ||
span.setTag("user_id", userId); | ||
|
||
const dbLessonPlan = await loadLessonPlanFromDatabase(chatId, userId); | ||
const { messages: dbMessages, lessonPlan: dbLessonPlan } = | ||
await loadChatDataFromDatabase(chatId, userId); | ||
|
||
const messages = prepareMessages(dbMessages, frontendMessages, chatId); | ||
|
||
setTelemetryMetadata({ | ||
span, | ||
|
@@ -264,38 +382,17 @@ export async function handleChatPostRequest( | |
options, | ||
}); | ||
|
||
aila = await withTelemetry( | ||
"chat-create-aila", | ||
{ chat_id: chatId, user_id: userId }, | ||
async (): Promise<Aila> => { | ||
const ailaOptions: Partial<AilaInitializationOptions> = { | ||
options, | ||
chat: { | ||
id: chatId, | ||
userId, | ||
messages, | ||
}, | ||
services: { | ||
chatLlmService: llmService, | ||
moderationAiClient, | ||
ragService: (aila: AilaServices) => new AilaRag({ aila }), | ||
americanismsService: () => | ||
new AilaAmericanisms<LooseLessonPlan>(), | ||
analyticsAdapters: (aila: AilaServices) => [ | ||
new PosthogAnalyticsAdapter(aila), | ||
new DatadogAnalyticsAdapter(aila), | ||
], | ||
threatDetectors: () => threatDetectors, | ||
}, | ||
|
||
document: { | ||
content: dbLessonPlan ?? {}, | ||
}, | ||
}; | ||
const result = await config.createAila(ailaOptions); | ||
return result; | ||
}, | ||
); | ||
aila = await createAilaInstance({ | ||
config, | ||
options, | ||
chatId, | ||
userId, | ||
messages, | ||
lessonPlan: dbLessonPlan, | ||
llmService, | ||
moderationAiClient, | ||
threatDetectors, | ||
}); | ||
Comment on lines
+385
to
+395
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Much nicer! There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Cheers! |
||
invariant(aila, "Aila instance is required"); | ||
|
||
const abortController = handleConnectionAborted(req); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I just want to confirm that we don't want to report to sentry here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks - I've added Sentry reporting. For now, no additional error handling for the front end though