From 2cfc16709d9af1f1b9a7d094f596e0c88bba0179 Mon Sep 17 00:00:00 2001 From: Dario Gieselaar Date: Tue, 7 Jan 2025 12:17:56 +0100 Subject: [PATCH] [Inference] Image content (#205371) Adds support for image content parts in the Inference plugin. Only base64 encoded images are supported, as this capability is shared across all three LLM providers. --------- Co-authored-by: Elastic Machine --- .../shared/ai-infra/inference-common/index.ts | 3 + .../src/chat_complete/index.ts | 3 + .../src/chat_complete/messages.ts | 16 +++- .../bedrock/bedrock_claude_adapter.test.ts | 86 +++++++++++++++++++ .../bedrock/bedrock_claude_adapter.ts | 21 ++++- .../chat_complete/adapters/bedrock/types.ts | 39 +++++++-- .../adapters/gemini/gemini_adapter.test.ts | 80 +++++++++++++++++ .../adapters/gemini/gemini_adapter.ts | 20 +++-- .../adapters/openai/openai_adapter.test.ts | 80 +++++++++++++++++ .../adapters/openai/to_openai.ts | 20 ++++- .../wrap_with_simulated_function_calling.ts | 18 +++- 11 files changed, 367 insertions(+), 19 deletions(-) diff --git a/x-pack/platform/packages/shared/ai-infra/inference-common/index.ts b/x-pack/platform/packages/shared/ai-infra/inference-common/index.ts index 0c6d254c0f527..b466d6ac6879b 100644 --- a/x-pack/platform/packages/shared/ai-infra/inference-common/index.ts +++ b/x-pack/platform/packages/shared/ai-infra/inference-common/index.ts @@ -10,6 +10,9 @@ export { ChatCompletionEventType, ToolChoiceType, type Message, + type MessageContentImage, + type MessageContentText, + type MessageContent, type AssistantMessage, type ToolMessage, type UserMessage, diff --git a/x-pack/platform/packages/shared/ai-infra/inference-common/src/chat_complete/index.ts b/x-pack/platform/packages/shared/ai-infra/inference-common/src/chat_complete/index.ts index cedc8297d75bc..227e72d93ca92 100644 --- a/x-pack/platform/packages/shared/ai-infra/inference-common/src/chat_complete/index.ts +++ b/x-pack/platform/packages/shared/ai-infra/inference-common/src/chat_complete/index.ts 
@@ -29,6 +29,9 @@ export { } from './events'; export { MessageRole, + type MessageContent, + type MessageContentImage, + type MessageContentText, type Message, type AssistantMessage, type UserMessage, diff --git a/x-pack/platform/packages/shared/ai-infra/inference-common/src/chat_complete/messages.ts b/x-pack/platform/packages/shared/ai-infra/inference-common/src/chat_complete/messages.ts index 43d03cf130c01..54b9b76d2bd8c 100644 --- a/x-pack/platform/packages/shared/ai-infra/inference-common/src/chat_complete/messages.ts +++ b/x-pack/platform/packages/shared/ai-infra/inference-common/src/chat_complete/messages.ts @@ -23,14 +23,26 @@ interface MessageBase<TRole extends MessageRole> { role: TRole; } +export interface MessageContentText { + type: 'text'; + text: string; +} + +export interface MessageContentImage { + type: 'image'; + source: { data: string; mimeType: string }; +} + +export type MessageContent = string | Array<MessageContentText | MessageContentImage>; + /** * Represents a message from the user. */ export type UserMessage = MessageBase<MessageRole.User> & { /** - * The text content of the user message + * The text or image content of the user message */ - content: string; + content: MessageContent; }; /** diff --git a/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/bedrock/bedrock_claude_adapter.test.ts b/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/bedrock/bedrock_claude_adapter.test.ts index c6114c3b09e95..b0a5bcbc71bca 100644 --- a/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/bedrock/bedrock_claude_adapter.test.ts +++ b/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/bedrock/bedrock_claude_adapter.test.ts @@ -256,6 +256,92 @@ describe('bedrockClaudeAdapter', () => { expect(system).toEqual('Some system message'); }); + it('correctly formats messages with content parts', () => { + bedrockClaudeAdapter.chatComplete({ + executor: executorMock, + logger, + messages: [ + { + role: MessageRole.User, + content: [ + { + type: 'text', + text: 
'question', + }, + ], + }, + { + role: MessageRole.Assistant, + content: 'answer', + }, + { + role: MessageRole.User, + content: [ + { + type: 'image', + source: { + data: 'aaaaaa', + mimeType: 'image/png', + }, + }, + { + type: 'image', + source: { + data: 'bbbbbb', + mimeType: 'image/png', + }, + }, + ], + }, + ], + }); + + expect(executorMock.invoke).toHaveBeenCalledTimes(1); + + const { messages } = getCallParams(); + expect(messages).toEqual([ + { + rawContent: [ + { + text: 'question', + type: 'text', + }, + ], + role: 'user', + }, + { + rawContent: [ + { + text: 'answer', + type: 'text', + }, + ], + role: 'assistant', + }, + { + rawContent: [ + { + type: 'image', + source: { + data: 'aaaaaa', + mediaType: 'image/png', + type: 'base64', + }, + }, + { + type: 'image', + source: { + data: 'bbbbbb', + mediaType: 'image/png', + type: 'base64', + }, + }, + ], + role: 'user', + }, + ]); + }); + it('correctly format tool choice', () => { bedrockClaudeAdapter.chatComplete({ executor: executorMock, diff --git a/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/bedrock/bedrock_claude_adapter.ts b/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/bedrock/bedrock_claude_adapter.ts index e34605a4c96ad..3500d12dc69fa 100644 --- a/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/bedrock/bedrock_claude_adapter.ts +++ b/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/bedrock/bedrock_claude_adapter.ts @@ -17,7 +17,7 @@ import { } from '@kbn/inference-common'; import { parseSerdeChunkMessage } from './serde_utils'; import { InferenceConnectorAdapter } from '../../types'; -import type { BedRockMessage, BedrockToolChoice } from './types'; +import type { BedRockImagePart, BedRockMessage, BedRockTextPart, BedrockToolChoice } from './types'; import { BedrockChunkMember, serdeEventstreamIntoObservable, @@ -153,7 +153,24 @@ const messagesToBedrock = (messages: Message[]): BedRockMessage[] => { case 
MessageRole.User: return { role: 'user' as const, - rawContent: [{ type: 'text' as const, text: message.content }], + rawContent: (typeof message.content === 'string' + ? [message.content] + : message.content + ).map((contentPart) => { + if (typeof contentPart === 'string') { + return { text: contentPart, type: 'text' } satisfies BedRockTextPart; + } else if (contentPart.type === 'text') { + return { text: contentPart.text, type: 'text' } satisfies BedRockTextPart; + } + return { + type: 'image', + source: { + data: contentPart.source.data, + mediaType: contentPart.source.mimeType, + type: 'base64', + }, + } satisfies BedRockImagePart; + }), }; case MessageRole.Assistant: return { diff --git a/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/bedrock/types.ts b/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/bedrock/types.ts index f0937a8d8ec18..805ee17c096e1 100644 --- a/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/bedrock/types.ts +++ b/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/bedrock/types.ts @@ -17,15 +17,38 @@ export interface BedRockMessage { /** * Bedrock message parts */ +export interface BedRockTextPart { + type: 'text'; + text: string; +} + +export interface BedRockToolUsePart { + type: 'tool_use'; + id: string; + name: string; + input: Record<string, unknown>; +} + +export interface BedRockToolResultPart { + type: 'tool_result'; + tool_use_id: string; + content: string; +} + +export interface BedRockImagePart { + type: 'image'; + source: { + type: 'base64'; + mediaType: string; + data: string; + }; +} + export type BedRockMessagePart = - | { type: 'text'; text: string } - | { - type: 'tool_use'; - id: string; - name: string; - input: Record<string, unknown>; - } - | { type: 'tool_result'; tool_use_id: string; content: string }; + | BedRockTextPart + | BedRockToolUsePart + | BedRockToolResultPart + | BedRockImagePart; export type BedrockToolChoice = { type: 'auto' } | { type: 'any' } | { 
type: 'tool'; name: string }; diff --git a/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/gemini/gemini_adapter.test.ts b/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/gemini/gemini_adapter.test.ts index 5024bd1f4c87e..e7eb75453e778 100644 --- a/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/gemini/gemini_adapter.test.ts +++ b/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/gemini/gemini_adapter.test.ts @@ -239,6 +239,86 @@ describe('geminiAdapter', () => { ]); }); + it('correctly formats content parts', () => { + geminiAdapter.chatComplete({ + executor: executorMock, + logger, + messages: [ + { + role: MessageRole.User, + content: [ + { + type: 'text', + text: 'question', + }, + ], + }, + { + role: MessageRole.Assistant, + content: 'answer', + }, + { + role: MessageRole.User, + content: [ + { + type: 'image', + source: { + data: 'aaaaaa', + mimeType: 'image/png', + }, + }, + { + type: 'image', + source: { + data: 'bbbbbb', + mimeType: 'image/png', + }, + }, + ], + }, + ], + }); + + expect(executorMock.invoke).toHaveBeenCalledTimes(1); + + const { messages } = getCallParams(); + expect(messages).toEqual([ + { + parts: [ + { + text: 'question', + }, + ], + role: 'user', + }, + { + parts: [ + { + text: 'answer', + }, + ], + role: 'assistant', + }, + { + parts: [ + { + inlineData: { + data: 'aaaaaa', + mimeType: 'image/png', + }, + }, + { + inlineData: { + data: 'bbbbbb', + mimeType: 'image/png', + }, + }, + ], + role: 'user', + }, + ]); + }); + it('groups messages from the same user', () => { geminiAdapter.chatComplete({ logger, diff --git a/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/gemini/gemini_adapter.ts b/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/gemini/gemini_adapter.ts index aa62f7006eac7..29b663be146d2 100644 --- a/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/gemini/gemini_adapter.ts 
+++ b/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/gemini/gemini_adapter.ts @@ -196,11 +196,21 @@ function messageToGeminiMapper() { case MessageRole.User: const userMessage: GeminiMessage = { role: 'user', - parts: [ - { - text: message.content, - }, - ], + parts: (typeof message.content === 'string' ? [message.content] : message.content).map( + (contentPart) => { + if (typeof contentPart === 'string') { + return { text: contentPart } satisfies Gemini.TextPart; + } else if (contentPart.type === 'text') { + return { text: contentPart.text } satisfies Gemini.TextPart; + } + return { + inlineData: { + data: contentPart.source.data, + mimeType: contentPart.source.mimeType, + }, + } satisfies Gemini.InlineDataPart; + } + ), }; return userMessage; diff --git a/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/openai/openai_adapter.test.ts b/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/openai/openai_adapter.test.ts index d93dee627ec18..c9699f006d96b 100644 --- a/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/openai/openai_adapter.test.ts +++ b/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/openai/openai_adapter.test.ts @@ -118,6 +118,86 @@ describe('openAIAdapter', () => { ]); }); + it('correctly formats messages with content parts', () => { + openAIAdapter.chatComplete({ + executor: executorMock, + logger, + messages: [ + { + role: MessageRole.User, + content: [ + { + type: 'text', + text: 'question', + }, + ], + }, + { + role: MessageRole.Assistant, + content: 'answer', + }, + { + role: MessageRole.User, + content: [ + { + type: 'image', + source: { + data: 'aaaaaa', + mimeType: 'image/png', + }, + }, + { + type: 'image', + source: { + data: 'bbbbbb', + mimeType: 'image/png', + }, + }, + ], + }, + ], + }); + + expect(executorMock.invoke).toHaveBeenCalledTimes(1); + + const { + body: { messages }, + } = getRequest(); + + expect(messages).toEqual([ + { 
+ content: [ + { + text: 'question', + type: 'text', + }, + ], + role: 'user', + }, + { + content: 'answer', + role: 'assistant', + }, + { + content: [ + { + type: 'image_url', + image_url: { + url: 'aaaaaa', + }, + }, + { + type: 'image_url', + image_url: { + url: 'bbbbbb', + }, + }, + ], + role: 'user', + }, + ]); + }); + it('correctly formats tools and tool choice', () => { openAIAdapter.chatComplete({ ...defaultArgs, diff --git a/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/openai/to_openai.ts b/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/openai/to_openai.ts index 709b1fd4c6bfe..66792963c425f 100644 --- a/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/openai/to_openai.ts +++ b/x-pack/platform/plugins/shared/inference/server/chat_complete/adapters/openai/to_openai.ts @@ -8,6 +8,8 @@ import type OpenAI from 'openai'; import type { ChatCompletionAssistantMessageParam, + ChatCompletionContentPartImage, + ChatCompletionContentPartText, ChatCompletionMessageParam, ChatCompletionSystemMessageParam, ChatCompletionToolMessageParam, @@ -90,7 +92,23 @@ export function messagesToOpenAI({ case MessageRole.User: const userMessage: ChatCompletionUserMessageParam = { role: 'user', - content: message.content, + content: + typeof message.content === 'string' + ? 
message.content + : message.content.map((contentPart) => { + if (contentPart.type === 'image') { + return { + type: 'image_url', + image_url: { + url: contentPart.source.data, + }, + } satisfies ChatCompletionContentPartImage; + } + return { + text: contentPart.text, + type: 'text', + } satisfies ChatCompletionContentPartText; + }), }; return userMessage; diff --git a/x-pack/platform/plugins/shared/inference/server/chat_complete/simulated_function_calling/wrap_with_simulated_function_calling.ts b/x-pack/platform/plugins/shared/inference/server/chat_complete/simulated_function_calling/wrap_with_simulated_function_calling.ts index d2cb0bfae4999..8c207617e9bf4 100644 --- a/x-pack/platform/plugins/shared/inference/server/chat_complete/simulated_function_calling/wrap_with_simulated_function_calling.ts +++ b/x-pack/platform/plugins/shared/inference/server/chat_complete/simulated_function_calling/wrap_with_simulated_function_calling.ts @@ -52,9 +52,25 @@ export function wrapWithSimulatedFunctionCalling({ return message; }) .map((message) => { + let content = message.content; + + if (typeof content === 'string') { + content = replaceFunctionsWithTools(content); + } else if (Array.isArray(content)) { + content = content.map((contentPart) => { + if (contentPart.type === 'text') { + return { + ...contentPart, + text: replaceFunctionsWithTools(contentPart.text), + }; + } + return contentPart; + }); + } + return { ...message, - content: message.content ? replaceFunctionsWithTools(message.content) : message.content, + content, }; });