From 4aee9a533bf287281f9ec8e8d8b44efdff5a6f71 Mon Sep 17 00:00:00 2001 From: ahaapple Date: Thu, 9 Jan 2025 23:15:40 +0800 Subject: [PATCH] Support message history compression --- frontend/app/api/search/route.ts | 34 ++++++------- frontend/components/search/search-bar.tsx | 9 +++- frontend/components/search/search-window.tsx | 23 ++++++++- frontend/hooks/use-compress-history.ts | 46 ++++++++++++++++++ frontend/lib/llm/llm.ts | 4 ++ frontend/lib/llm/utils.ts | 25 +++++++++- frontend/lib/store/local-store.ts | 18 +++++++ frontend/lib/tools/auto.ts | 3 +- frontend/lib/tools/chat.ts | 4 +- frontend/lib/tools/compress-history.ts | 50 ++++++++++++++++++++ frontend/lib/tools/generate-ui.ts | 6 +-- frontend/lib/tools/indie.ts | 6 +-- frontend/lib/tools/knowledge-base.ts | 6 +-- frontend/lib/tools/product.ts | 6 +-- frontend/lib/types.ts | 2 + 15 files changed, 204 insertions(+), 38 deletions(-) create mode 100644 frontend/hooks/use-compress-history.ts create mode 100644 frontend/lib/tools/compress-history.ts diff --git a/frontend/app/api/search/route.ts b/frontend/app/api/search/route.ts index 7a852a28..73043476 100644 --- a/frontend/app/api/search/route.ts +++ b/frontend/app/api/search/route.ts @@ -42,24 +42,9 @@ export async function POST(req: NextRequest) { const userId = session?.user?.id ?? ''; const isPro = session?.user ? isProUser(session.user) : false; - let { model, source, messages, profile, isSearch, questionLanguage, answerLanguage } = await req.json(); + let { model, source, messages, profile, isSearch, questionLanguage, answerLanguage, summary } = await req.json(); - console.log( - 'model', - model, - 'source', - source, - 'messages', - messages, - 'userId', - userId, - 'isSearch', - isSearch, - 'questionLanguage', - questionLanguage, - 'answerLanguage', - answerLanguage, - ); + console.log('model', model, 'source', source, 'messages', messages.length, 'userId', userId, 'isSearch', isSearch, 'summary', summary); if (isProModel(model) && !isPro) { return NextResponse.json( @@ -90,7 +75,7 @@ export async function POST(req: NextRequest) { break; } case SearchCategory.CHAT: { - await chat(messages, isPro, userId, profile, streamController(controller), answerLanguage, model); + await chat(messages, isPro, userId, profile, summary, streamController(controller), answerLanguage, model); break; } case SearchCategory.PRODUCT_HUNT: { @@ -106,7 +91,18 @@ export async function POST(req: NextRequest) { break; } default: { - await autoAnswer(messages, isPro, userId, profile, streamController(controller), questionLanguage, answerLanguage, model, source); + await autoAnswer( + messages, + isPro, + userId, + profile, + summary, + streamController(controller), + questionLanguage, + answerLanguage, + model, + source, + ); } } }, diff --git a/frontend/components/search/search-bar.tsx b/frontend/components/search/search-bar.tsx index 0c11dc7a..1d294ca0 100644 --- a/frontend/components/search/search-bar.tsx +++ b/frontend/components/search/search-bar.tsx @@ -23,6 +23,7 @@ import { SearchType } from '@/lib/types'; import WebImageModal, { WebImageFile } from '@/components/modal/web-images-model'; import { isImageInputModel } from '@/lib/model'; import { SearchSettingsDialog } from '@/components/search/search-settings'; +import { useCompressHistory } from '@/hooks/use-compress-history'; interface Props { handleSearch: (key: string, attachments?: string[]) => void; @@ -122,6 +123,12 @@ const SearchBar: React.FC = ({ } }; + const { compressHistoryMessages } = useCompressHistory(); + const handleContentChange = (e: React.ChangeEvent) => { + setContent(e.target.value); + compressHistoryMessages(); + }; + const [files, setFiles] = useState([]); const dropzoneRef = useRef(null); const { onUpload, uploadedFiles, setUploadedFiles, isUploading } = useUploadFile(); @@ -268,7 +275,7 @@ const SearchBar: React.FC = ({ aria-label="Search" className="w-full border-0 bg-transparent p-4 mb-8 text-sm placeholder:text-muted-foreground overflow-y-auto outline-0 ring-0 focus-visible:outline-none focus-visible:ring-0 resize-none" onKeyDown={handleInputKeydown} - onChange={(e) => setContent(e.target.value)} + onChange={handleContentChange} >
diff --git a/frontend/components/search/search-window.tsx b/frontend/components/search/search-window.tsx index 9092f8df..a0f9aee9 100644 --- a/frontend/components/search/search-window.tsx +++ b/frontend/components/search/search-window.tsx @@ -6,7 +6,7 @@ import SearchMessage from '@/components/search/search-message'; import { fetchEventSource } from '@microsoft/fetch-event-source'; import { useSearchParams } from 'next/navigation'; import { useSigninModal } from '@/hooks/use-signin-modal'; -import { useConfigStore, useProfileStore, useUIStore } from '@/lib/store/local-store'; +import { useConfigStore, useProfileStore, useSearchState, useUIStore } from '@/lib/store/local-store'; import { ImageSource, Message, SearchType, TextSource, User, VideoSource } from '@/lib/types'; import { LoaderCircle } from 'lucide-react'; @@ -66,6 +66,7 @@ export default function SearchWindow({ id, initialMessages, user, isReadOnly = f // monitorMemoryUsage(); const { incrementSearchCount, canSearch } = useSearchLimit(); + const { isCompressHistory } = useSearchState(); const sendMessage = useCallback( async (question?: string, attachments?: string[], messageIdToUpdate?: string) => { @@ -119,6 +120,24 @@ export default function SearchWindow({ id, initialMessages, user, isReadOnly = f if (!messageValue && attachments && searchType === 'ui') { messageValue = 'Please generate the same UI as the image'; } + + const waitForCompression = async () => { + if (!isCompressHistory) return; + + return new Promise((resolve) => { + const checkCompressionStatus = () => { + if (!isCompressHistory) { + resolve(); + } else { + console.log('Waiting for compression to finish...'); + setTimeout(checkCompressionStatus, 100); + } + }; + checkCompressionStatus(); + }); + }; + await waitForCompression(); + // const imageUrls = extractAllImageUrls(messageValue); // if (imageUrls.length > 1 && user && !isProUser(user)) { // toast.error(t('multi-image-free-limit')); @@ -200,6 +219,7 @@ export default function SearchWindow({ id, initialMessages, user, isReadOnly = f title: title, createdAt: new Date(), userId: user?.id, + lastCompressIndex: 0, messages: [ { id: activeId, @@ -244,6 +264,7 @@ export default function SearchWindow({ id, initialMessages, user, isReadOnly = f isSearch: useUIStore.getState().isSearch, isShadcnUI: useUIStore.getState().isShadcnUI, messages: useSearchStore.getState().activeSearch.messages, + summary: useSearchStore.getState().activeSearch.summary, }), openWhenHidden: true, onerror(err) { diff --git a/frontend/hooks/use-compress-history.ts b/frontend/hooks/use-compress-history.ts new file mode 100644 index 00000000..5631f49d --- /dev/null +++ b/frontend/hooks/use-compress-history.ts @@ -0,0 +1,46 @@ +import { isProModel } from '@/lib/model'; +import { useSearchStore } from '@/lib/store/local-history'; +import { useConfigStore, useSearchState } from '@/lib/store/local-store'; +import { compressHistory } from '@/lib/tools/compress-history'; + +export function useCompressHistory() { + const { activeSearch, updateActiveSearch } = useSearchStore(); + const { isCompressHistory, setIsCompressHistory } = useSearchState(); + + const compressHistoryMessages = async () => { + console.log('compressHistoryMessages'); + if (isCompressHistory) return; + if (!activeSearch?.messages) return; + + const messages = activeSearch.messages; + const totalMessages = messages.length; + if (totalMessages < 4) return; + + const model = useConfigStore.getState().model; + if (!isProModel(model)) { + return; + } + + console.log('compressHistoryMessages totalMessages:', totalMessages); + + try { + const compressIndex = activeSearch.lastCompressIndex || 0; + const newMessagesToCompress = messages.slice(compressIndex); + console.log('compressHistoryMessages newMessagesToCompress:', newMessagesToCompress, compressIndex); + if (newMessagesToCompress.length < 4 || newMessagesToCompress.length > 6) { + return; + } + setIsCompressHistory(true); + const newSummary = await compressHistory(newMessagesToCompress, activeSearch.summary); + if (newSummary.length > 0) { + const newCompressIndex = totalMessages; + updateActiveSearch({ summary: newSummary, lastCompressIndex: newCompressIndex }); + } + } catch (error) { + console.error('Failed to compress history:', error); + } finally { + setIsCompressHistory(false); + } + }; + return { compressHistoryMessages }; +} diff --git a/frontend/lib/llm/llm.ts b/frontend/lib/llm/llm.ts index a0bec826..60c22bcd 100644 --- a/frontend/lib/llm/llm.ts +++ b/frontend/lib/llm/llm.ts @@ -69,6 +69,10 @@ export function convertToCoreMessages(messages: Message[]): CoreMessage[] { coreMessages.push({ role: 'assistant', content: message.content }); break; } + case 'system': { + coreMessages.push({ role: 'system', content: message.content }); + break; + } default: { throw new Error(`Unhandled role: ${message.role}`); } diff --git a/frontend/lib/llm/utils.ts b/frontend/lib/llm/utils.ts index 20adc48d..8f5a6379 100644 --- a/frontend/lib/llm/utils.ts +++ b/frontend/lib/llm/utils.ts @@ -1,10 +1,31 @@ +import { Message } from '@/lib/types'; import 'server-only'; -export function getHistoryMessages(isPro: boolean, messages: any[]) { +export function getHistoryMessages(isPro: boolean, messages: any[], summary?: string) { const sliceNum = isPro ? -7 : -3; - return messages?.slice(sliceNum); + const slicedMessages = messages?.slice(sliceNum); + if (summary) { + return [ + { + content: summary, + role: 'system', + }, + ...slicedMessages.slice(-2), + ]; + } + return slicedMessages; } +const formatMessage = (message: Message) => { + return `<${message.role}>${message.content}`; +}; + +export const formatHistoryMessages = (messages: Message[]) => { + return ` + ${messages.map((m) => formatMessage(m)).join('\n')} + `; +}; + export function getHistory(isPro: boolean, messages: any[]) { const sliceNum = isPro ? -7 : -3; return messages diff --git a/frontend/lib/store/local-store.ts b/frontend/lib/store/local-store.ts index d291f01d..8a635c32 100644 --- a/frontend/lib/store/local-store.ts +++ b/frontend/lib/store/local-store.ts @@ -116,3 +116,21 @@ export const useIndexStore = create()( }, ), ); + +interface SearchState { + isTyping: boolean; + isCompressHistory: boolean; + isSearching: boolean; + setIsSearching: (status: boolean) => void; + setIsTyping: (status: boolean) => void; + setIsCompressHistory: (status: boolean) => void; +} + +export const useSearchState = create()((set) => ({ + isTyping: false, + isCompressHistory: false, + isSearching: false, + setIsSearching: (status: boolean) => set({ isSearching: status }), + setIsTyping: (status: boolean) => set({ isTyping: status }), + setIsCompressHistory: (status: boolean) => set({ isCompressHistory: status }), +})); diff --git a/frontend/lib/tools/auto.ts b/frontend/lib/tools/auto.ts index ae31a37c..5f2f30d4 100644 --- a/frontend/lib/tools/auto.ts +++ b/frontend/lib/tools/auto.ts @@ -62,6 +62,7 @@ export async function autoAnswer( isPro: boolean, userId: string, profile?: string, + summary?: string, onStream?: (...args: any[]) => void, questionLanguage?: string, answerLanguage?: string, @@ -69,7 +70,7 @@ export async function autoAnswer( source = SearchCategory.ALL, ) { try { - const newMessages = getHistoryMessages(isPro, messages); + const newMessages = getHistoryMessages(isPro, messages, summary); const query = newMessages[newMessages.length - 1].content; let texts: TextSource[] = []; diff --git a/frontend/lib/tools/chat.ts b/frontend/lib/tools/chat.ts index d6a4d4fb..52aa0c76 100644 --- a/frontend/lib/tools/chat.ts +++ b/frontend/lib/tools/chat.ts @@ -21,15 +21,15 @@ export async function chat( isPro: boolean, userId: string, profile?: string, + summary?: string, onStream?: (...args: any[]) => void, answerLanguage?: string, model = GPT_4o_MIMI, ) { try { - const newMessages = getHistoryMessages(isPro, messages); + const newMessages = getHistoryMessages(isPro, messages, summary); const query = newMessages[newMessages.length - 1].content; - // console.log('answerLanguage', answerLanguage); let languageInstructions = ''; if (answerLanguage !== 'auto') { languageInstructions = util.format(UserLanguagePrompt, answerLanguage); diff --git a/frontend/lib/tools/compress-history.ts b/frontend/lib/tools/compress-history.ts new file mode 100644 index 00000000..09d941db --- /dev/null +++ b/frontend/lib/tools/compress-history.ts @@ -0,0 +1,50 @@ +'use server'; + +import { getLLM } from '@/lib/llm/llm'; +import { formatHistoryMessages } from '@/lib/llm/utils'; +import { GPT_4o_MIMI } from '@/lib/model'; +import { Message } from '@/lib/types'; +import { generateText } from 'ai'; + +export async function compressHistory(messages: Message[], previousSummary?: string): Promise { + if (messages.length < 4) { + return ''; + } + + try { + console.log('compressHistory messages:', messages); + console.log('compressHistory previousSummary:', previousSummary); + console.time('compressHistory'); + + const systemPrompt = `You're an assistant who's good at extracting key takeaways from conversations and summarizing them. Please summarize according to the user's needs. ${ + previousSummary + ? 'Please incorporate the previous summary with new messages to create an updated comprehensive summary.' + : 'Create a new summary from the messages.' + }`; + + const userPrompt = previousSummary + ? `Previous Summary: ${previousSummary}\n\nNew Messages: ${formatHistoryMessages(messages)}\n\nPlease create an updated summary incorporating both the previous summary and new messages. Limit to 400 tokens.` + : `${formatHistoryMessages(messages)}\nPlease summarize the above conversation and retain key information. Limit to 400 tokens.`; + + const { text } = await generateText({ + model: getLLM(GPT_4o_MIMI), + messages: [ + { + content: systemPrompt, + role: 'system', + }, + { + content: userPrompt, + role: 'user', + }, + ], + }); + + console.timeEnd('compressHistory'); + console.log('compressHistory text:', text); + return text; + } catch (error) { + console.error('Error compress history:', error); + return previousSummary || ''; + } +} diff --git a/frontend/lib/tools/generate-ui.ts b/frontend/lib/tools/generate-ui.ts index df4be6ed..4e97220c 100644 --- a/frontend/lib/tools/generate-ui.ts +++ b/frontend/lib/tools/generate-ui.ts @@ -107,9 +107,9 @@ export async function generateUI( onStream?.(JSON.stringify({ answer: text })); } - incSearchCount(userId).catch((error) => { - console.error(`Failed to increment search count for user ${userId}:`, error); - }); + // incSearchCount(userId).catch((error) => { + // console.error(`Failed to increment search count for user ${userId}:`, error); + // }); await saveMessages(userId, messages, fullAnswer, [], [], [], '', SearchCategory.UI); } catch (error) { diff --git a/frontend/lib/tools/indie.ts b/frontend/lib/tools/indie.ts index f1d12fe0..c2c1bd29 100644 --- a/frontend/lib/tools/indie.ts +++ b/frontend/lib/tools/indie.ts @@ -102,9 +102,9 @@ export async function indieMakerSearch( ); }); - incSearchCount(userId).catch((error) => { - console.error(`Failed to increment search count for user ${userId}:`, error); - }); + // incSearchCount(userId).catch((error) => { + // console.error(`Failed to increment search count for user ${userId}:`, error); + // }); await saveMessages(userId, messages, fullAnswer, texts, images, videos, fullRelated); onStream?.(null, true); diff --git a/frontend/lib/tools/knowledge-base.ts b/frontend/lib/tools/knowledge-base.ts index cbc32232..64644428 100644 --- a/frontend/lib/tools/knowledge-base.ts +++ b/frontend/lib/tools/knowledge-base.ts @@ -66,9 +66,9 @@ export async function knowledgeBaseSearch(messages: StoreMessage[], isPro: boole return; } - incSearchCount(userId).catch((error) => { - console.error(`Failed to increment search count for user ${userId}:`, error); - }); + // incSearchCount(userId).catch((error) => { + // console.error(`Failed to increment search count for user ${userId}:`, error); + // }); await saveMessages(userId, messages, fullAnswer, texts, [], [], ''); onStream?.(null, true); diff --git a/frontend/lib/tools/product.ts b/frontend/lib/tools/product.ts index 5c0dda68..80bfc9a5 100644 --- a/frontend/lib/tools/product.ts +++ b/frontend/lib/tools/product.ts @@ -99,9 +99,9 @@ export async function productSearch( ); }); - incSearchCount(userId).catch((error) => { - console.error(`Failed to increment search count for user ${userId}:`, error); - }); + // incSearchCount(userId).catch((error) => { + // console.error(`Failed to increment search count for user ${userId}:`, error); + // }); await saveMessages(userId, messages, fullAnswer, texts, images, videos, fullRelated); onStream?.(null, true); diff --git a/frontend/lib/types.ts b/frontend/lib/types.ts index 2800fac3..8aae2b58 100644 --- a/frontend/lib/types.ts +++ b/frontend/lib/types.ts @@ -95,6 +95,8 @@ export interface Search extends Record { userId: string; messages: Message[]; sharePath?: string; + summary?: string; + lastCompressIndex?: number; } export interface GenImage extends Record {