Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: do not trust front end message history #570

Merged
merged 11 commits into from
Mar 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
203 changes: 150 additions & 53 deletions apps/nextjs/src/app/api/chat/chatHandler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import {
PosthogAnalyticsAdapter,
} from "@oakai/aila/src/features/analytics";
import { AilaRag } from "@oakai/aila/src/features/rag/AilaRag";
import type { AilaThreatDetector } from "@oakai/aila/src/features/threatDetection";
import { HeliconeThreatDetector } from "@oakai/aila/src/features/threatDetection/detectors/helicone/HeliconeThreatDetector";
import { LakeraThreatDetector } from "@oakai/aila/src/features/threatDetection/detectors/lakera/LakeraThreatDetector";
import type { LooseLessonPlan } from "@oakai/aila/src/protocol/schema";
Expand All @@ -20,6 +21,7 @@ import { withTelemetry } from "@oakai/core/src/tracing/serverTracing";
import type { PrismaClientWithAccelerate } from "@oakai/db";
import { prisma as globalPrisma } from "@oakai/db/client";
import { aiLogger } from "@oakai/logger";
import { captureException } from "@sentry/nextjs";
import type { NextRequest } from "next/server";
import invariant from "tiny-invariant";

Expand Down Expand Up @@ -178,30 +180,63 @@ function isValidLessonPlan(lessonPlan: unknown): boolean {
return lessonPlan !== null && typeof lessonPlan === "object";
}

function parseLessonPlanFromOutput(output: unknown): LooseLessonPlan {
if (!output) return {};
function hasMessages(obj: unknown): obj is { messages: unknown } {
return obj !== null && typeof obj === "object" && "messages" in obj;
}

function isValidMessages(messages: unknown): boolean {
return Array.isArray(messages);
}

function verifyChatOwnership(
chat: { userId: string },
requestUserId: string,
chatId: string,
): void {
if (chat.userId !== requestUserId) {
log.error(
`User ${requestUserId} attempted to access chat ${chatId} which belongs to ${chat.userId}`,
);
throw new Error("Unauthorized access to chat");
}
}

function parseChatOutput(
output: unknown,
chatId: string,
): { messages: Message[]; lessonPlan: LooseLessonPlan } {
let messages: Message[] = [];
let lessonPlan: LooseLessonPlan = {};

try {
const parsedOutput =
typeof output === "string" ? JSON.parse(output) : output;

if (hasMessages(parsedOutput) && isValidMessages(parsedOutput.messages)) {
messages = parsedOutput.messages as Message[];
}

if (
hasLessonPlan(parsedOutput) &&
isValidLessonPlan(parsedOutput.lessonPlan)
) {
return parsedOutput.lessonPlan as LooseLessonPlan;
lessonPlan = parsedOutput.lessonPlan as LooseLessonPlan;
}
} catch (error) {
log.error("Error parsing output to extract lesson plan", error);
log.error(`Error parsing output for chat ${chatId}`, error);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just want to confirm that we don't want to report to sentry here

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks - I've added Sentry reporting. For now, no additional error handling for the front end though

captureException(error, {
extra: { chatId, output },
tags: { context: "parseChatOutput" },
});
}

return {};
return { messages, lessonPlan };
}

async function loadLessonPlanFromDatabase(
async function loadChatDataFromDatabase(
chatId: string,
userId: string,
): Promise<LooseLessonPlan> {
): Promise<{ messages: Message[]; lessonPlan: LooseLessonPlan }> {
try {
const chat = await prisma.appSession.findUnique({
where: { id: chatId },
Expand All @@ -214,33 +249,113 @@ async function loadLessonPlanFromDatabase(

if (!chat) {
log.info(`No existing chat found for id: ${chatId}`);
return {};
return { messages: [], lessonPlan: {} };
}

if (chat.userId !== userId) {
log.error(
`User ${userId} attempted to access chat ${chatId} which belongs to ${chat.userId}`,
);
throw new Error("Unauthorized access to chat");
}
verifyChatOwnership(chat, userId, chatId);

const { messages, lessonPlan } = parseChatOutput(chat.output, chatId);

const lessonPlan = parseLessonPlanFromOutput(chat.output);
log.info(`Loaded lesson plan for chat ${chatId}`);
return lessonPlan;
log.info(
`Loaded ${messages.length} messages and lesson plan for chat ${chatId}`,
);
return { messages, lessonPlan };
} catch (error) {
log.error(`Error loading lesson plan for chat ${chatId}`, error);
return {};
log.error(`Error loading chat data for chat ${chatId}`, error);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again here, just want to confirm that we don't want to report to sentry

captureException(error, {
extra: { chatId, userId },
tags: { context: "loadChatDataFromDatabase" },
});
throw error;
}
}

function extractLatestUserMessage(frontendMessages: Message[]): Message | null {
return (frontendMessages ?? []).findLast((m) => m?.role === "user") ?? null;
}

function prepareMessages(
dbMessages: Message[],
frontendMessages: Message[],
chatId: string,
): Message[] {
const latestUserMessage = extractLatestUserMessage(frontendMessages);

let messages = [...dbMessages];
if (
latestUserMessage &&
!messages.some((m) => m.id === latestUserMessage.id)
) {
messages.push(latestUserMessage);
log.info(`Appended new user message to history for chat ${chatId}`);
}

return messages;
}

type CreateAilaInstanceArguments = {
config: Config;
options: AilaOptions;
chatId: string;
userId: string | undefined;
messages: Message[];
lessonPlan: LooseLessonPlan;
llmService: ReturnType<typeof getFixtureLLMService>;
moderationAiClient: ReturnType<typeof getFixtureModerationOpenAiClient>;
threatDetectors: AilaThreatDetector[];
};

async function createAilaInstance({
config,
options,
chatId,
userId,
messages,
lessonPlan,
llmService,
moderationAiClient,
threatDetectors,
}: CreateAilaInstanceArguments): Promise<Aila> {
return await withTelemetry(
"chat-create-aila",
{ chat_id: chatId, user_id: userId },
async (): Promise<Aila> => {
const ailaOptions: Partial<AilaInitializationOptions> = {
options,
chat: {
id: chatId,
userId,
messages,
},
services: {
chatLlmService: llmService,
moderationAiClient,
ragService: (aila: AilaServices) => new AilaRag({ aila }),
americanismsService: () => new AilaAmericanisms(),
analyticsAdapters: (aila: AilaServices) => [
new PosthogAnalyticsAdapter(aila),
new DatadogAnalyticsAdapter(aila),
],
threatDetectors: () => threatDetectors,
},
document: {
content: lessonPlan ?? {},
},
};
const result = await config.createAila(ailaOptions);
return result;
},
);
}

export async function handleChatPostRequest(
req: NextRequest,
config: Config,
): Promise<Response> {
return await withTelemetry("chat-api", {}, async (span: TracingSpan) => {
const {
chatId,
messages,
messages: frontendMessages,
options,
llmService,
moderationAiClient,
Expand All @@ -254,7 +369,10 @@ export async function handleChatPostRequest(
userId = await fetchAndCheckUser(chatId);
span.setTag("user_id", userId);

const dbLessonPlan = await loadLessonPlanFromDatabase(chatId, userId);
const { messages: dbMessages, lessonPlan: dbLessonPlan } =
await loadChatDataFromDatabase(chatId, userId);

const messages = prepareMessages(dbMessages, frontendMessages, chatId);

setTelemetryMetadata({
span,
Expand All @@ -264,38 +382,17 @@ export async function handleChatPostRequest(
options,
});

aila = await withTelemetry(
"chat-create-aila",
{ chat_id: chatId, user_id: userId },
async (): Promise<Aila> => {
const ailaOptions: Partial<AilaInitializationOptions> = {
options,
chat: {
id: chatId,
userId,
messages,
},
services: {
chatLlmService: llmService,
moderationAiClient,
ragService: (aila: AilaServices) => new AilaRag({ aila }),
americanismsService: () =>
new AilaAmericanisms<LooseLessonPlan>(),
analyticsAdapters: (aila: AilaServices) => [
new PosthogAnalyticsAdapter(aila),
new DatadogAnalyticsAdapter(aila),
],
threatDetectors: () => threatDetectors,
},

document: {
content: dbLessonPlan ?? {},
},
};
const result = await config.createAila(ailaOptions);
return result;
},
);
aila = await createAilaInstance({
config,
options,
chatId,
userId,
messages,
lessonPlan: dbLessonPlan,
llmService,
moderationAiClient,
threatDetectors,
});
Comment on lines +385 to +395
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Much nicer!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cheers!

invariant(aila, "Aila instance is required");

const abortController = handleConnectionAborted(req);
Expand Down
1 change: 0 additions & 1 deletion apps/nextjs/src/components/AppComponents/Chat/AiSdk.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ export function AiSdk({ id }: Readonly<AiSdkProps>) {
const [hasFinished, setHasFinished] = useState(true);

const initialMessages = useChatStore((state) => state.initialMessages);
const lessonPlan = useLessonPlanStore((state) => state.lessonPlan);
const streamingFinished = useChatStore((state) => state.streamingFinished);
const scrollToBottom = useChatStore((state) => state.scrollToBottom);
const messageStarted = useLessonPlanStore((state) => state.messageStarted);
Expand Down
Empty file.
Empty file.
Loading