From 7ca34d40fb20b1561977a13259a541bf3051b921 Mon Sep 17 00:00:00 2001 From: Narendranath Gogineni Date: Mon, 14 Oct 2024 14:12:32 +0530 Subject: [PATCH 1/7] Bedrock converse integration fix tool choice fix tool choice fix interface add support for prompting with images changes per comments changes per comments Add back support for Cohere command and AI21 Jurassic 2 models because they are not supported by bedrock converse API handle prompting with documents on bedrock converse support prompting with documents --- src/globals.ts | 49 + src/handlers/streamHandler.ts | 16 +- src/providers/bedrock/api.ts | 14 +- src/providers/bedrock/chatComplete.ts | 1467 +++++++---------------- src/providers/bedrock/constants.ts | 7 + src/providers/bedrock/index.ts | 95 +- src/providers/bedrock/utils.ts | 137 +-- src/providers/google-vertex-ai/utils.ts | 23 +- src/providers/types.ts | 1 + 9 files changed, 621 insertions(+), 1188 deletions(-) diff --git a/src/globals.ts b/src/globals.ts index 5d4326826..9a78a4759 100644 --- a/src/globals.ts +++ b/src/globals.ts @@ -139,3 +139,52 @@ export const MULTIPART_FORM_DATA_ENDPOINTS: endpointStrings[] = [ 'createTranscription', 'createTranslation', ]; + +export const fileExtensionMimeTypeMap = { + mp4: 'video/mp4', + jpeg: 'image/jpeg', + jpg: 'image/jpeg', + png: 'image/png', + bmp: 'image/bmp', + tiff: 'image/tiff', + webp: 'image/webp', + pdf: 'application/pdf', + csv: 'text/csv', + doc: 'application/msword', + docx: 'application/vnd.openxmlformats-officedocument.wordprocessingml.document', + xls: 'application/vnd.ms-excel', + xlsx: 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet', + html: 'text/html', + md: 'text/markdown', + mp3: 'audio/mp3', + wav: 'audio/wav', + txt: 'text/plain', + mov: 'video/mov', + mpeg: 'video/mpeg', + mpg: 'video/mpg', + avi: 'video/avi', + wmv: 'video/wmv', + mpegps: 'video/mpegps', + flv: 'video/flv', +}; + +export const imagesMimeTypes = [ + fileExtensionMimeTypeMap.jpeg, + 
fileExtensionMimeTypeMap.jpg, + fileExtensionMimeTypeMap.png, + fileExtensionMimeTypeMap.bmp, + fileExtensionMimeTypeMap.tiff, + fileExtensionMimeTypeMap.webp, +]; + +export const documentMimeTypes = [ + fileExtensionMimeTypeMap.pdf, + fileExtensionMimeTypeMap.csv, + fileExtensionMimeTypeMap.doc, + fileExtensionMimeTypeMap.docx, + fileExtensionMimeTypeMap.xls, + fileExtensionMimeTypeMap.xlsx, + fileExtensionMimeTypeMap.html, + fileExtensionMimeTypeMap.md, + fileExtensionMimeTypeMap.txt, +]; diff --git a/src/handlers/streamHandler.ts b/src/handlers/streamHandler.ts index a8fc462d9..36b311f4f 100644 --- a/src/handlers/streamHandler.ts +++ b/src/handlers/streamHandler.ts @@ -33,7 +33,10 @@ function getPayloadFromAWSChunk(chunk: Uint8Array): string { const payloadLength = chunkLength - headersEnd - 4; // Subtracting 4 for the message crc const payload = chunk.slice(headersEnd, headersEnd + payloadLength); - return decoder.decode(payload); + const decodedJson = JSON.parse(decoder.decode(payload)); + return decodedJson.bytes + ? 
Buffer.from(decodedJson.bytes, 'base64').toString() + : JSON.stringify(decodedJson); } function concatenateUint8Arrays(a: Uint8Array, b: Uint8Array): Uint8Array { @@ -60,10 +63,7 @@ export async function* readAWSStream( const data = buffer.subarray(0, expectedLength); buffer = buffer.subarray(expectedLength); expectedLength = readUInt32BE(buffer, 0); - const payload = Buffer.from( - JSON.parse(getPayloadFromAWSChunk(data)).bytes, - 'base64' - ).toString(); + const payload = getPayloadFromAWSChunk(data); if (transformFunction) { const transformedChunk = transformFunction( payload, @@ -96,11 +96,7 @@ export async function* readAWSStream( buffer = buffer.subarray(expectedLength); expectedLength = readUInt32BE(buffer, 0); - const payload = Buffer.from( - JSON.parse(getPayloadFromAWSChunk(data)).bytes, - 'base64' - ).toString(); - + const payload = getPayloadFromAWSChunk(data); if (transformFunction) { const transformedChunk = transformFunction( payload, diff --git a/src/providers/bedrock/api.ts b/src/providers/bedrock/api.ts index 3c7350a6c..92f3968c7 100644 --- a/src/providers/bedrock/api.ts +++ b/src/providers/bedrock/api.ts @@ -1,4 +1,6 @@ +import { GatewayError } from '../../errors/GatewayError'; import { ProviderAPIConfig } from '../types'; +import { bedrockInvokeModels } from './constants'; import { generateAWSHeaders } from './utils'; const BedrockAPIConfig: ProviderAPIConfig = { @@ -27,12 +29,20 @@ const BedrockAPIConfig: ProviderAPIConfig = { }, getEndpoint: ({ fn, gatewayRequestBody }) => { const { model, stream } = gatewayRequestBody; + if (!model) throw new GatewayError('Model is required'); let mappedFn = fn; if (stream) { mappedFn = `stream-${fn}`; } - const endpoint = `/model/${model}/invoke`; - const streamEndpoint = `/model/${model}/invoke-with-response-stream`; + let endpoint = `/model/${model}/invoke`; + let streamEndpoint = `/model/${model}/invoke-with-response-stream`; + if ( + (mappedFn === 'chatComplete' || mappedFn === 'stream-chatComplete') && 
+ !bedrockInvokeModels.includes(model) + ) { + endpoint = `/model/${model}/converse`; + streamEndpoint = `/model/${model}/converse-stream`; + } switch (mappedFn) { case 'chatComplete': { return endpoint; diff --git a/src/providers/bedrock/chatComplete.ts b/src/providers/bedrock/chatComplete.ts index a4f6ecf7e..16038d411 100644 --- a/src/providers/bedrock/chatComplete.ts +++ b/src/providers/bedrock/chatComplete.ts @@ -1,5 +1,5 @@ -import { BEDROCK } from '../../globals'; -import { ContentType, Message, Params } from '../../types/requestBody'; +import { BEDROCK, documentMimeTypes, imagesMimeTypes } from '../../globals'; +import { Message, Params, ToolCall } from '../../types/requestBody'; import { ChatCompletionResponse, ErrorResponse, @@ -13,610 +13,284 @@ import { BedrockAI21CompleteResponse, BedrockCohereCompleteResponse, BedrockCohereStreamChunk, - BedrockLlamaCompleteResponse, - BedrockLlamaStreamChunk, - BedrockTitanCompleteResponse, - BedrockTitanStreamChunk, - BedrockMistralCompleteResponse, - BedrocMistralStreamChunk, } from './complete'; import { BedrockErrorResponse } from './embed'; import { - transformMessagesForLLama2Prompt, - transformMessagesForLLama3Prompt, - transformMessagesForMistralPrompt, + transformAdditionalModelRequestFields, + transformInferenceConfig, } from './utils'; -interface AnthropicTool { - name: string; - description: string; - input_schema: { - type: string; - properties: Record< - string, - { - type: string; - description: string; - } - >; - required: string[]; +export interface BedrockChatCompletionsParams extends Params { + additionalModelRequestFields?: Record; + additionalModelResponseFieldPaths?: string[]; + guardrailConfig?: { + guardrailIdentifier: string; + guardrailVersion: string; + trace?: string; }; + anthropic_version?: string; + countPenalty?: number; } -interface AnthropicToolResultContentItem { - type: 'tool_result'; - tool_use_id: string; - content?: string; -} - -type AnthropicMessageContentItem = 
AnthropicToolResultContentItem | ContentType; - -interface AnthropicMessage extends Message { - content?: string | AnthropicMessageContentItem[]; -} - -interface AnthorpicTextContentItem { - type: 'text'; - text: string; -} - -interface AnthropicToolContentItem { - type: 'tool_use'; - name: string; - id: string; - input: Record; -} - -type AnthropicContentItem = AnthorpicTextContentItem | AnthropicToolContentItem; - -const transformAssistantMessageForAnthropic = ( - msg: Message -): AnthropicMessage => { - let content: AnthropicContentItem[] = []; - const containsToolCalls = msg.tool_calls && msg.tool_calls.length; - - if (msg.content && typeof msg.content === 'string') { - content.push({ - type: 'text', - text: msg.content, - }); - } else if ( - msg.content && - typeof msg.content === 'object' && - msg.content.length - ) { - if (msg.content[0].text) { - content.push({ - type: 'text', - text: msg.content[0].text, - }); - } - } - if (containsToolCalls) { - msg.tool_calls.forEach((toolCall: any) => { - content.push({ - type: 'tool_use', - name: toolCall.function.name, - id: toolCall.id, - input: JSON.parse(toolCall.function.arguments), +const getMessageTextContentArray = (message: Message): { text: string }[] => { + if (message.content && typeof message.content === 'object') { + return message.content + .filter((item) => item.type === 'text') + .map((item) => { + return { + text: item.text || '', + }; }); - }); } - return { - role: msg.role, - content, - }; + return [ + { + text: message.content || '', + }, + ]; }; -const transformToolMessageForAnthropic = (msg: Message): AnthropicMessage => { - return { - role: 'user', - content: [ +const getMessageContent = (message: Message) => { + if (!message.content) return []; + if (message.role === 'tool') { + return [ { - type: 'tool_result', - tool_use_id: msg.tool_call_id, - content: msg.content as string, + toolResult: { + content: getMessageTextContentArray(message), + toolUseId: message.tool_call_id, + }, }, - ], - }; + 
]; + } + const out = []; + // if message is a string, return a single element array with the text + if (typeof message.content === 'string') { + out.push({ + text: message.content, + }); + } else { + message.content.forEach((item) => { + if (item.type === 'text') { + out.push({ + text: item.text || '', + }); + } else if (item.type === 'image_url' && item.image_url) { + const mimetypeParts = item.image_url.url.split(';'); + const mimeType = mimetypeParts[0].split(':')[1]; + const fileFormat = mimeType.split('/')[1]; + const bytes = mimetypeParts[1].split(',')[1]; + if (imagesMimeTypes.includes(mimeType)) { + out.push({ + image: { + source: { + bytes, + }, + format: fileFormat, + }, + }); + } else if (documentMimeTypes.includes(mimeType)) { + out.push({ + document: { + format: fileFormat, + name: crypto.randomUUID(), + source: { + bytes, + }, + }, + }); + } + } + }); + } + + // If message is an array of objects, handle text content, tool calls, tool results, this would be much cleaner if portkeys chat create object were a union type + message.tool_calls?.forEach((toolCall: ToolCall) => { + out.push({ + toolUse: { + name: toolCall.function.name, + input: JSON.parse(toolCall.function.arguments), + toolUseId: toolCall.id, + }, + }); + }); + return out; }; -export const BedrockAnthropicChatCompleteConfig: ProviderConfig = { +// refer: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html +export const BedrockConverseChatCompleteConfig: ProviderConfig = { messages: [ { param: 'messages', required: true, - transform: (params: Params) => { - let messages: AnthropicMessage[] = []; - // Transform the chat messages into a simple prompt - if (!!params.messages) { - params.messages.forEach((msg) => { - if (msg.role === 'system') return; - - if (msg.role === 'assistant') { - messages.push(transformAssistantMessageForAnthropic(msg)); - } else if ( - msg.content && - typeof msg.content === 'object' && - msg.content.length - ) { - const 
transformedMessage: Record = { - role: msg.role, - content: [], - }; - msg.content.forEach((item) => { - if (item.type === 'text') { - transformedMessage.content.push({ - type: item.type, - text: item.text, - }); - } else if ( - item.type === 'image_url' && - item.image_url && - item.image_url.url - ) { - const parts = item.image_url.url.split(';'); - if (parts.length === 2) { - const base64ImageParts = parts[1].split(','); - const base64Image = base64ImageParts[1]; - const mediaTypeParts = parts[0].split(':'); - if (mediaTypeParts.length === 2 && base64Image) { - const mediaType = mediaTypeParts[1]; - transformedMessage.content.push({ - type: 'image', - source: { - type: 'base64', - media_type: mediaType, - data: base64Image, - }, - }); - } - } - } - }); - messages.push(transformedMessage as Message); - } else if (msg.role === 'tool') { - // even though anthropic supports images in tool results, openai doesn't support it yet - messages.push(transformToolMessageForAnthropic(msg)); - } else { - messages.push({ - role: msg.role, - content: msg.content, - }); - } + transform: (params: BedrockChatCompletionsParams) => { + if (!params.messages) return []; + return params.messages + .filter((msg) => msg.role !== 'system') + .map((msg) => { + return { + role: msg.role === 'assistant' ? 
'assistant' : 'user', + content: getMessageContent(msg), + }; }); - } - - return messages; }, }, { param: 'system', required: false, - transform: (params: Params) => { - let systemMessage: string = ''; - // Transform the chat messages into a simple prompt - if (!!params.messages) { - params.messages.forEach((msg) => { - if ( - msg.role === 'system' && - msg.content && - typeof msg.content === 'object' && - msg.content[0].text - ) { - systemMessage = msg.content[0].text; - } else if ( - msg.role === 'system' && - typeof msg.content === 'string' - ) { - systemMessage = msg.content; - } - }); - } - return systemMessage; + transform: (params: BedrockChatCompletionsParams) => { + if (!params.messages) return; + const systemMessages = params.messages.reduce( + (acc: { text: string }[], msg) => { + if (msg.role === 'system') + return acc.concat(...getMessageTextContentArray(msg)); + return acc; + }, + [] + ); + if (!systemMessages.length) return; + return systemMessages; }, }, ], tools: { - param: 'tools', - required: false, - transform: (params: Params) => { - let tools: AnthropicTool[] = []; - if (params.tools) { - params.tools.forEach((tool) => { - if (tool.function) { - tools.push({ + param: 'toolConfig', + transform: (params: BedrockChatCompletionsParams) => { + const toolConfig = { + tools: params.tools?.map((tool) => { + if (!tool.function) return; + return { + toolSpec: { name: tool.function.name, - description: tool.function?.description || '', - input_schema: { - type: tool.function.parameters?.type || 'object', - properties: tool.function.parameters?.properties || {}, - required: tool.function.parameters?.required || [], - }, - }); - } - }); - } - return tools; - }, - }, - // None is not supported by Anthropic, defaults to auto - tool_choice: { - param: 'tool_choice', - required: false, - transform: (params: Params) => { + description: tool.function.description, + inputSchema: { json: tool.function.parameters }, + }, + }; + }), + }; + let toolChoice = 
undefined; if (params.tool_choice) { - if (typeof params.tool_choice === 'string') { - if (params.tool_choice === 'required') return { type: 'any' }; - else if (params.tool_choice === 'auto') return { type: 'auto' }; - } else if (typeof params.tool_choice === 'object') { - return { type: 'tool', name: params.tool_choice.function.name }; - } - } - return null; - }, - }, - max_tokens: { - param: 'max_tokens', - required: true, - }, - max_completion_tokens: { - param: 'max_tokens', - }, - temperature: { - param: 'temperature', - default: 1, - min: 0, - max: 1, - }, - top_p: { - param: 'top_p', - default: -1, - min: -1, - }, - top_k: { - param: 'top_k', - default: -1, - }, - stop: { - param: 'stop_sequences', - transform: (params: Params) => { - if (params.stop === null) { - return []; - } - return params.stop; - }, - }, - user: { - param: 'metadata.user_id', - }, - anthropic_version: { - param: 'anthropic_version', - required: true, - default: 'bedrock-2023-05-31', - }, -}; - -export const BedrockCohereChatCompleteConfig: ProviderConfig = { - messages: { - param: 'prompt', - required: true, - transform: (params: Params) => { - let prompt: string = ''; - if (!!params.messages) { - let messages: Message[] = params.messages; - messages.forEach((msg, index) => { - if (index === 0 && msg.role === 'system') { - prompt += `system: ${messages}\n`; - } else if (msg.role == 'user') { - prompt += `user: ${msg.content}\n`; - } else if (msg.role == 'assistant') { - prompt += `assistant: ${msg.content}\n`; - } else { - prompt += `${msg.role}: ${msg.content}\n`; - } - }); - prompt += 'Assistant:'; - } - return prompt; - }, - }, - max_tokens: { - param: 'max_tokens', - default: 20, - min: 1, - }, - max_completion_tokens: { - param: 'max_tokens', - default: 20, - min: 1, - }, - temperature: { - param: 'temperature', - default: 0.75, - min: 0, - max: 5, - }, - top_p: { - param: 'p', - default: 0.75, - min: 0, - max: 1, - }, - top_k: { - param: 'k', - default: 0, - max: 500, - }, - 
frequency_penalty: { - param: 'frequency_penalty', - default: 0, - min: 0, - max: 1, - }, - presence_penalty: { - param: 'presence_penalty', - default: 0, - min: 0, - max: 1, - }, - logit_bias: { - param: 'logit_bias', - }, - n: { - param: 'num_generations', - default: 1, - min: 1, - max: 5, - }, - stop: { - param: 'end_sequences', - }, - stream: { - param: 'stream', - }, -}; - -export const BedrockLlama2ChatCompleteConfig: ProviderConfig = { - messages: { - param: 'prompt', - required: true, - transform: (params: Params) => { - if (!params.messages) return ''; - return transformMessagesForLLama2Prompt(params.messages); - }, - }, - max_tokens: { - param: 'max_gen_len', - default: 512, - min: 1, - max: 2048, - }, - max_completion_tokens: { - param: 'max_gen_len', - default: 512, - min: 1, - max: 2048, - }, - temperature: { - param: 'temperature', - default: 0.5, - min: 0, - max: 1, - }, - top_p: { - param: 'top_p', - default: 0.9, - min: 0, - max: 1, - }, -}; - -export const BedrockLlama3ChatCompleteConfig: ProviderConfig = { - messages: { - param: 'prompt', - required: true, - transform: (params: Params) => { - if (!params.messages) return ''; - return transformMessagesForLLama3Prompt(params.messages); - }, - }, - max_tokens: { - param: 'max_gen_len', - default: 512, - min: 1, - }, - temperature: { - param: 'temperature', - default: 0.5, - min: 0, - max: 1, - }, - top_p: { - param: 'top_p', - default: 0.9, - min: 0, - max: 1, - }, -}; - -export const BedrockMistralChatCompleteConfig: ProviderConfig = { - messages: { - param: 'prompt', - required: true, - transform: (params: Params) => { - let prompt: string = ''; - if (!!params.messages) - prompt = transformMessagesForMistralPrompt(params.messages); - return prompt; - }, - }, - max_tokens: { - param: 'max_tokens', - default: 20, - min: 1, - }, - max_completion_tokens: { - param: 'max_tokens', - default: 20, - min: 1, - }, - temperature: { - param: 'temperature', - default: 0.75, - min: 0, - max: 5, - }, - top_p: { 
- param: 'top_p', - default: 0.75, - min: 0, - max: 1, - }, - top_k: { - param: 'top_k', - default: 0, - max: 200, - }, - stop: { - param: 'stop', - }, -}; - -const transformTitanGenerationConfig = (params: Params) => { - const generationConfig: Record = {}; - if (params['temperature']) { - generationConfig['temperature'] = params['temperature']; - } - if (params['top_p']) { - generationConfig['topP'] = params['top_p']; - } - if (params['max_tokens']) { - generationConfig['maxTokenCount'] = params['max_tokens']; - } - if (params['max_completion_tokens']) { - generationConfig['maxTokenCount'] = params['max_completion_tokens']; - } - if (params['stop']) { - generationConfig['stopSequences'] = params['stop']; - } - return generationConfig; -}; - -export const BedrockTitanChatompleteConfig: ProviderConfig = { - messages: { - param: 'inputText', - required: true, - transform: (params: Params) => { - let prompt: string = ''; - if (!!params.messages) { - let messages: Message[] = params.messages; - messages.forEach((msg, index) => { - if (index === 0 && msg.role === 'system') { - prompt += `system: ${messages}\n`; - } else if (msg.role == 'user') { - prompt += `user: ${msg.content}\n`; - } else if (msg.role == 'assistant') { - prompt += `assistant: ${msg.content}\n`; - } else { - prompt += `${msg.role}: ${msg.content}\n`; + if (typeof params.tool_choice === 'object') { + toolChoice = { + tool: { + name: params.tool_choice.function.name, + }, + }; + } else if (typeof params.tool_choice === 'string') { + if (params.tool_choice === 'required') { + toolChoice = { + any: {}, + }; + } else if (params.tool_choice === 'auto') { + toolChoice = { + auto: {}, + }; } - }); - prompt += 'Assistant:'; + } } - return prompt; + return { ...toolConfig, toolChoice }; }, }, - temperature: { - param: 'textGenerationConfig', - transform: (params: Params) => transformTitanGenerationConfig(params), - }, - max_tokens: { - param: 'textGenerationConfig', - transform: (params: Params) => 
transformTitanGenerationConfig(params), - }, - max_completion_tokens: { - param: 'textGenerationConfig', - transform: (params: Params) => transformTitanGenerationConfig(params), - }, - top_p: { - param: 'textGenerationConfig', - transform: (params: Params) => transformTitanGenerationConfig(params), + guardrailConfig: { + param: 'guardrailConfig', + required: false, }, -}; - -export const BedrockAI21ChatCompleteConfig: ProviderConfig = { - messages: { - param: 'prompt', - required: true, - transform: (params: Params) => { - let prompt: string = ''; - if (!!params.messages) { - let messages: Message[] = params.messages; - messages.forEach((msg, index) => { - if (index === 0 && msg.role === 'system') { - prompt += `system: ${messages}\n`; - } else if (msg.role == 'user') { - prompt += `user: ${msg.content}\n`; - } else if (msg.role == 'assistant') { - prompt += `assistant: ${msg.content}\n`; - } else { - prompt += `${msg.role}: ${msg.content}\n`; - } - }); - prompt += 'Assistant:'; - } - return prompt; - }, + additionalModelResponseFieldPaths: { + param: 'additionalModelResponseFieldPaths', + required: false, }, max_tokens: { - param: 'maxTokens', - default: 200, + param: 'inferenceConfig', + transform: (params: BedrockChatCompletionsParams) => + transformInferenceConfig(params), }, - max_completion_tokens: { - param: 'maxTokens', - default: 200, + stop: { + param: 'inferenceConfig', + transform: (params: BedrockChatCompletionsParams) => + transformInferenceConfig(params), }, temperature: { - param: 'temperature', - default: 0.7, - min: 0, - max: 1, + param: 'inferenceConfig', + transform: (params: BedrockChatCompletionsParams) => + transformInferenceConfig(params), }, top_p: { - param: 'topP', - default: 1, + param: 'inferenceConfig', + transform: (params: BedrockChatCompletionsParams) => + transformInferenceConfig(params), }, - stop: { - param: 'stopSequences', + additionalModelRequestFields: { + param: 'additionalModelRequestFields', + transform: (params: 
BedrockChatCompletionsParams) => + transformAdditionalModelRequestFields(params), }, - presence_penalty: { - param: 'presencePenalty', - transform: (params: Params) => { - return { - scale: params.presence_penalty, - }; - }, + top_k: { + param: 'additionalModelRequestFields', + transform: (params: BedrockChatCompletionsParams) => + transformAdditionalModelRequestFields(params), + }, + anthropic_version: { + param: 'additionalModelRequestFields', + transform: (params: BedrockChatCompletionsParams) => + transformAdditionalModelRequestFields(params), }, frequency_penalty: { - param: 'frequencyPenalty', - transform: (params: Params) => { - return { - scale: params.frequency_penalty, - }; - }, + param: 'additionalModelRequestFields', + transform: (params: BedrockChatCompletionsParams) => + transformAdditionalModelRequestFields(params), }, - countPenalty: { - param: 'countPenalty', + presence_penalty: { + param: 'additionalModelRequestFields', + transform: (params: BedrockChatCompletionsParams) => + transformAdditionalModelRequestFields(params), }, - frequencyPenalty: { - param: 'frequencyPenalty', + logit_bias: { + param: 'additionalModelRequestFields', + transform: (params: BedrockChatCompletionsParams) => + transformAdditionalModelRequestFields(params), }, - presencePenalty: { - param: 'presencePenalty', + n: { + param: 'additionalModelRequestFields', + transform: (params: BedrockChatCompletionsParams) => + transformAdditionalModelRequestFields(params), + }, + stream: { + param: 'additionalModelRequestFields', + transform: (params: BedrockChatCompletionsParams) => + transformAdditionalModelRequestFields(params), + }, + countPenalty: { + param: 'additionalModelRequestFields', + transform: (params: BedrockChatCompletionsParams) => + transformAdditionalModelRequestFields(params), }, }; +interface BedrockChatCompletionResponse { + metrics: { + latencyMs: number; + }; + output: { + message: { + role: string; + content: [ + { + text: string; + toolUse: { + toolUseId: 
string; + name: string; + input: object; + }; + }, + ]; + }; + }; + stopReason: string; + usage: { + inputTokens: number; + outputTokens: number; + totalTokens: number; + }; +} + export const BedrockErrorResponseTransform: ( response: BedrockErrorResponse ) => ErrorResponse | undefined = (response) => { @@ -630,198 +304,8 @@ export const BedrockErrorResponseTransform: ( return undefined; }; -export const BedrockLlamaChatCompleteResponseTransform: ( - response: BedrockLlamaCompleteResponse | BedrockErrorResponse, - responseStatus: number -) => ChatCompletionResponse | ErrorResponse = (response, responseStatus) => { - if (responseStatus !== 200) { - const errorResposne = BedrockErrorResponseTransform( - response as BedrockErrorResponse - ); - if (errorResposne) return errorResposne; - } - - if ('generation' in response) { - return { - id: Date.now().toString(), - object: 'chat.completion', - created: Math.floor(Date.now() / 1000), - model: '', - provider: BEDROCK, - choices: [ - { - index: 0, - message: { - role: 'assistant', - content: response.generation, - }, - finish_reason: response.stop_reason, - }, - ], - usage: { - prompt_tokens: response.prompt_token_count, - completion_tokens: response.generation_token_count, - total_tokens: - response.prompt_token_count + response.generation_token_count, - }, - }; - } - - return generateInvalidProviderResponseError(response, BEDROCK); -}; - -export const BedrockLlamaChatCompleteStreamChunkTransform: ( - response: string, - fallbackId: string -) => string | string[] = (responseChunk, fallbackId) => { - let chunk = responseChunk.trim(); - chunk = chunk.trim(); - const parsedChunk: BedrockLlamaStreamChunk = JSON.parse(chunk); - - if (parsedChunk.stop_reason) { - return [ - `data: ${JSON.stringify({ - id: fallbackId, - object: 'text_completion', - created: Math.floor(Date.now() / 1000), - model: '', - provider: BEDROCK, - choices: [ - { - delta: {}, - index: 0, - logprobs: null, - finish_reason: parsedChunk.stop_reason, - }, - 
], - usage: { - prompt_tokens: - parsedChunk['amazon-bedrock-invocationMetrics'].inputTokenCount, - completion_tokens: - parsedChunk['amazon-bedrock-invocationMetrics'].outputTokenCount, - total_tokens: - parsedChunk['amazon-bedrock-invocationMetrics'].inputTokenCount + - parsedChunk['amazon-bedrock-invocationMetrics'].outputTokenCount, - }, - })}\n\n`, - `data: [DONE]\n\n`, - ]; - } - - return `data: ${JSON.stringify({ - id: fallbackId, - object: 'chat.completion.chunk', - created: Math.floor(Date.now() / 1000), - model: '', - provider: BEDROCK, - choices: [ - { - index: 0, - delta: { - role: 'assistant', - content: parsedChunk.generation, - }, - finish_reason: null, - }, - ], - })}\n\n`; -}; - -export const BedrockTitanChatCompleteResponseTransform: ( - response: BedrockTitanCompleteResponse | BedrockErrorResponse, - responseStatus: number -) => ChatCompletionResponse | ErrorResponse = (response, responseStatus) => { - if (responseStatus !== 200) { - const errorResposne = BedrockErrorResponseTransform( - response as BedrockErrorResponse - ); - if (errorResposne) return errorResposne; - } - - if ('results' in response) { - const completionTokens = response.results - .map((r) => r.tokenCount) - .reduce((partialSum, a) => partialSum + a, 0); - return { - id: Date.now().toString(), - object: 'chat.completion', - created: Math.floor(Date.now() / 1000), - model: '', - provider: BEDROCK, - choices: response.results.map((generation, index) => ({ - index: index, - message: { - role: 'assistant', - content: generation.outputText, - }, - finish_reason: generation.completionReason, - })), - usage: { - prompt_tokens: response.inputTextTokenCount, - completion_tokens: completionTokens, - total_tokens: response.inputTextTokenCount + completionTokens, - }, - }; - } - - return generateInvalidProviderResponseError(response, BEDROCK); -}; - -export const BedrockTitanChatCompleteStreamChunkTransform: ( - response: string, - fallbackId: string -) => string | string[] = 
(responseChunk, fallbackId) => { - let chunk = responseChunk.trim(); - chunk = chunk.trim(); - const parsedChunk: BedrockTitanStreamChunk = JSON.parse(chunk); - - return [ - `data: ${JSON.stringify({ - id: fallbackId, - object: 'chat.completion.chunk', - created: Math.floor(Date.now() / 1000), - model: '', - provider: BEDROCK, - choices: [ - { - index: 0, - delta: { - role: 'assistant', - content: parsedChunk.outputText, - }, - finish_reason: null, - }, - ], - })}\n\n`, - `data: ${JSON.stringify({ - id: fallbackId, - object: 'chat.completion.chunk', - created: Math.floor(Date.now() / 1000), - model: '', - provider: BEDROCK, - choices: [ - { - index: 0, - delta: {}, - finish_reason: parsedChunk.completionReason, - }, - ], - usage: { - prompt_tokens: - parsedChunk['amazon-bedrock-invocationMetrics'].inputTokenCount, - completion_tokens: - parsedChunk['amazon-bedrock-invocationMetrics'].outputTokenCount, - total_tokens: - parsedChunk['amazon-bedrock-invocationMetrics'].inputTokenCount + - parsedChunk['amazon-bedrock-invocationMetrics'].outputTokenCount, - }, - })}\n\n`, - `data: [DONE]\n\n`, - ]; -}; - -export const BedrockAI21ChatCompleteResponseTransform: ( - response: BedrockAI21CompleteResponse | BedrockErrorResponse, +export const BedrockChatCompleteResponseTransform: ( + response: BedrockChatCompletionResponse | BedrockErrorResponse, responseStatus: number, responseHeaders: Headers ) => ChatCompletionResponse | ErrorResponse = ( @@ -830,176 +314,92 @@ export const BedrockAI21ChatCompleteResponseTransform: ( responseHeaders ) => { if (responseStatus !== 200) { - const errorResposne = BedrockErrorResponseTransform( + const errorResponse = BedrockErrorResponseTransform( response as BedrockErrorResponse ); - if (errorResposne) return errorResposne; + if (errorResponse) return errorResponse; } - if ('completions' in response) { - const prompt_tokens = - Number(responseHeaders.get('X-Amzn-Bedrock-Input-Token-Count')) || 0; - const completion_tokens = - 
Number(responseHeaders.get('X-Amzn-Bedrock-Output-Token-Count')) || 0; - return { - id: response.id.toString(), + if ('output' in response) { + const responseObj: ChatCompletionResponse = { + id: Date.now().toString(), object: 'chat.completion', created: Math.floor(Date.now() / 1000), model: '', provider: BEDROCK, - choices: response.completions.map((completion, index) => ({ - index: index, - message: { - role: 'assistant', - content: completion.data.text, - }, - finish_reason: completion.finishReason?.reason, - })), - usage: { - prompt_tokens: prompt_tokens, - completion_tokens: completion_tokens, - total_tokens: prompt_tokens + completion_tokens, - }, - }; - } - - return generateInvalidProviderResponseError(response, BEDROCK); -}; - -interface BedrockAnthropicChatCompleteResponse { - id: string; - type: string; - role: string; - content: AnthropicContentItem[]; - stop_reason: string; - model: string; - stop_sequence: null | string; -} - -export const BedrockAnthropicChatCompleteResponseTransform: ( - response: BedrockAnthropicChatCompleteResponse | BedrockErrorResponse, - responseStatus: number, - responseHeaders: Headers -) => ChatCompletionResponse | ErrorResponse = ( - response, - responseStatus, - responseHeaders -) => { - if (responseStatus !== 200) { - const errorResposne = BedrockErrorResponseTransform( - response as BedrockErrorResponse - ); - if (errorResposne) return errorResposne; - } - - if ('content' in response) { - const prompt_tokens = - Number(responseHeaders.get('X-Amzn-Bedrock-Input-Token-Count')) || 0; - const completion_tokens = - Number(responseHeaders.get('X-Amzn-Bedrock-Output-Token-Count')) || 0; - - let content = ''; - if (response.content.length && response.content[0].type === 'text') { - content = response.content[0].text; - } - - let toolCalls: any = []; - response.content.forEach((item) => { - if (item.type === 'tool_use') { - toolCalls.push({ - id: item.id, - type: 'function', - function: { - name: item.name, - arguments: 
JSON.stringify(item.input), - }, - }); - } - }); - - return { - id: response.id, - object: 'chat.completion', - created: Math.floor(Date.now() / 1000), - model: response.model, - provider: BEDROCK, choices: [ { + index: 0, message: { role: 'assistant', - content, - tool_calls: toolCalls.length ? toolCalls : undefined, + content: response.output.message.content + .filter((content) => content.text) + .reduce((acc, content) => acc + content.text + '\n', ''), }, - index: 0, - logprobs: null, - finish_reason: response.stop_reason, + finish_reason: response.stopReason, }, ], usage: { - prompt_tokens: prompt_tokens, - completion_tokens: completion_tokens, - total_tokens: prompt_tokens + completion_tokens, + prompt_tokens: response.usage.inputTokens, + completion_tokens: response.usage.outputTokens, + total_tokens: response.usage.totalTokens, }, }; + const toolCalls = response.output.message.content + .filter((content) => content.toolUse) + .map((content) => ({ + id: content.toolUse.toolUseId, + type: 'function', + function: { + name: content.toolUse.name, + arguments: content.toolUse.input, + }, + })); + if (toolCalls.length > 0) + responseObj.choices[0].message.tool_calls = toolCalls; + return responseObj; } return generateInvalidProviderResponseError(response, BEDROCK); }; -interface BedrockAnthropicChatCompleteStreamResponse { - type: string; - index: number; - delta: { - type: string; +export interface BedrockChatCompleteStreamChunk { + contentBlockIndex?: number; + delta?: { text: string; - partial_json?: string; - stop_reason?: string; + toolUse: { + toolUseId: string; + name: string; + input: object; + }; }; - content_block?: { - type: string; - id?: string; - text?: string; - name?: string; - input?: {}; + stopReason?: string; + metrics?: { + latencyMs: number; }; - 'amazon-bedrock-invocationMetrics': { - inputTokenCount: number; - outputTokenCount: number; - invocationLatency: number; - firstByteLatency: number; + usage?: { + inputTokens: number; + outputTokens: 
number; + totalTokens: number; }; } -export const BedrockAnthropicChatCompleteStreamChunkTransform: ( - response: string, - fallbackId: string, - streamState: Record -) => string | string[] | undefined = ( - responseChunk, - fallbackId, - streamState -) => { - let chunk = responseChunk.trim(); - - const parsedChunk: BedrockAnthropicChatCompleteStreamResponse = - JSON.parse(chunk); - if ( - parsedChunk.type === 'ping' || - parsedChunk.type === 'message_start' || - parsedChunk.type === 'content_block_stop' - ) { - return []; - } +interface BedrockStreamState { + stopReason?: string; +} - if ( - parsedChunk.type === 'content_block_start' && - parsedChunk.content_block?.type === 'text' - ) { - streamState.containsChainOfThoughtMessage = true; - return; +// refer: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ConverseStream.html +export const BedrockChatCompleteStreamChunkTransform: ( + response: string, + fallbackId: string, + streamState: BedrockStreamState +) => string | string[] = (responseChunk, fallbackId, streamState) => { + const parsedChunk: BedrockChatCompleteStreamChunk = JSON.parse(responseChunk); + if (parsedChunk.stopReason) { + streamState.stopReason = parsedChunk.stopReason; } - if (parsedChunk.type === 'message_stop') { + if (parsedChunk.usage) { return [ `data: ${JSON.stringify({ id: fallbackId, @@ -1011,71 +411,27 @@ export const BedrockAnthropicChatCompleteStreamChunkTransform: ( { index: 0, delta: {}, - finish_reason: parsedChunk.delta?.stop_reason, + finish_reason: streamState.stopReason, }, ], usage: { - prompt_tokens: - parsedChunk['amazon-bedrock-invocationMetrics'].inputTokenCount, - completion_tokens: - parsedChunk['amazon-bedrock-invocationMetrics'].outputTokenCount, - total_tokens: - parsedChunk['amazon-bedrock-invocationMetrics'].inputTokenCount + - parsedChunk['amazon-bedrock-invocationMetrics'].outputTokenCount, + prompt_tokens: parsedChunk.usage.inputTokens, + completion_tokens: parsedChunk.usage.outputTokens, + 
total_tokens: parsedChunk.usage.totalTokens, }, })}\n\n`, - 'data: [DONE]\n\n', - ]; - } - - if (parsedChunk.delta?.stop_reason) { - return [ - `data: ${JSON.stringify({ - id: fallbackId, - object: 'chat.completion.chunk', - created: Math.floor(Date.now() / 1000), - model: '', - provider: BEDROCK, - choices: [ - { - delta: { - content: parsedChunk.delta?.text, - }, - index: 0, - logprobs: null, - finish_reason: parsedChunk.delta?.stop_reason ?? null, - }, - ], - })}\n\n`, + `data: [DONE]\n\n`, ]; } const toolCalls = []; - const isToolBlockStart: boolean = - parsedChunk.type === 'content_block_start' && - !!parsedChunk.content_block?.id; - const isToolBlockDelta: boolean = - parsedChunk.type === 'content_block_delta' && - !!parsedChunk.delta.partial_json; - const toolIndex: number = streamState.containsChainOfThoughtMessage - ? parsedChunk.index - 1 - : parsedChunk.index; - - if (isToolBlockStart && parsedChunk.content_block) { + if (parsedChunk.delta?.toolUse) { toolCalls.push({ - index: toolIndex, - id: parsedChunk.content_block.id, + id: parsedChunk.delta.toolUse.toolUseId, type: 'function', function: { - name: parsedChunk.content_block.name, - arguments: '', - }, - }); - } else if (isToolBlockDelta) { - toolCalls.push({ - index: toolIndex, - function: { - arguments: parsedChunk.delta.partial_json, + name: parsedChunk.delta.toolUse.name, + arguments: parsedChunk.delta.toolUse.input, }, }); } @@ -1088,18 +444,97 @@ export const BedrockAnthropicChatCompleteStreamChunkTransform: ( provider: BEDROCK, choices: [ { + index: parsedChunk.contentBlockIndex ?? 0, delta: { + role: 'assistant', content: parsedChunk.delta?.text, - tool_calls: toolCalls.length ? toolCalls : undefined, + tool_calls: toolCalls, }, - index: 0, - logprobs: null, - finish_reason: parsedChunk.delta?.stop_reason ?? 
null, }, ], })}\n\n`; }; +export const BedrockCohereChatCompleteConfig: ProviderConfig = { + messages: { + param: 'prompt', + required: true, + transform: (params: Params) => { + let prompt: string = ''; + if (!!params.messages) { + let messages: Message[] = params.messages; + messages.forEach((msg, index) => { + if (index === 0 && msg.role === 'system') { + prompt += `system: ${messages}\n`; + } else if (msg.role == 'user') { + prompt += `user: ${msg.content}\n`; + } else if (msg.role == 'assistant') { + prompt += `assistant: ${msg.content}\n`; + } else { + prompt += `${msg.role}: ${msg.content}\n`; + } + }); + prompt += 'Assistant:'; + } + return prompt; + }, + }, + max_tokens: { + param: 'max_tokens', + default: 20, + min: 1, + }, + max_completion_tokens: { + param: 'max_tokens', + default: 20, + min: 1, + }, + temperature: { + param: 'temperature', + default: 0.75, + min: 0, + max: 5, + }, + top_p: { + param: 'p', + default: 0.75, + min: 0, + max: 1, + }, + top_k: { + param: 'k', + default: 0, + max: 500, + }, + frequency_penalty: { + param: 'frequency_penalty', + default: 0, + min: 0, + max: 1, + }, + presence_penalty: { + param: 'presence_penalty', + default: 0, + min: 0, + max: 1, + }, + logit_bias: { + param: 'logit_bias', + }, + n: { + param: 'num_generations', + default: 1, + min: 1, + max: 5, + }, + stop: { + param: 'end_sequences', + }, + stream: { + param: 'stream', + }, +}; + export const BedrockCohereChatCompleteResponseTransform: ( response: BedrockCohereCompleteResponse | BedrockErrorResponse, responseStatus: number, @@ -1204,8 +639,80 @@ export const BedrockCohereChatCompleteStreamChunkTransform: ( })}\n\n`; }; -export const BedrockMistralChatCompleteResponseTransform: ( - response: BedrockMistralCompleteResponse | BedrockErrorResponse, +export const BedrockAI21ChatCompleteConfig: ProviderConfig = { + messages: { + param: 'prompt', + required: true, + transform: (params: Params) => { + let prompt: string = ''; + if (!!params.messages) { + let 
messages: Message[] = params.messages; + messages.forEach((msg, index) => { + if (index === 0 && msg.role === 'system') { + prompt += `system: ${messages}\n`; + } else if (msg.role == 'user') { + prompt += `user: ${msg.content}\n`; + } else if (msg.role == 'assistant') { + prompt += `assistant: ${msg.content}\n`; + } else { + prompt += `${msg.role}: ${msg.content}\n`; + } + }); + prompt += 'Assistant:'; + } + return prompt; + }, + }, + max_tokens: { + param: 'maxTokens', + default: 200, + }, + max_completion_tokens: { + param: 'maxTokens', + default: 200, + }, + temperature: { + param: 'temperature', + default: 0.7, + min: 0, + max: 1, + }, + top_p: { + param: 'topP', + default: 1, + }, + stop: { + param: 'stopSequences', + }, + presence_penalty: { + param: 'presencePenalty', + transform: (params: Params) => { + return { + scale: params.presence_penalty, + }; + }, + }, + frequency_penalty: { + param: 'frequencyPenalty', + transform: (params: Params) => { + return { + scale: params.frequency_penalty, + }; + }, + }, + countPenalty: { + param: 'countPenalty', + }, + frequencyPenalty: { + param: 'frequencyPenalty', + }, + presencePenalty: { + param: 'presencePenalty', + }, +}; + +export const BedrockAI21ChatCompleteResponseTransform: ( + response: BedrockAI21CompleteResponse | BedrockErrorResponse, responseStatus: number, responseHeaders: Headers ) => ChatCompletionResponse | ErrorResponse = ( @@ -1220,27 +727,25 @@ export const BedrockMistralChatCompleteResponseTransform: ( if (errorResposne) return errorResposne; } - if ('outputs' in response) { + if ('completions' in response) { const prompt_tokens = Number(responseHeaders.get('X-Amzn-Bedrock-Input-Token-Count')) || 0; const completion_tokens = Number(responseHeaders.get('X-Amzn-Bedrock-Output-Token-Count')) || 0; return { - id: Date.now().toString(), + id: response.id.toString(), object: 'chat.completion', created: Math.floor(Date.now() / 1000), model: '', provider: BEDROCK, - choices: [ - { - index: 0, - message: 
{ - role: 'assistant', - content: response.outputs[0].text, - }, - finish_reason: response.outputs[0].stop_reason, + choices: response.completions.map((completion, index) => ({ + index: index, + message: { + role: 'assistant', + content: completion.data.text, }, - ], + finish_reason: completion.finishReason?.reason, + })), usage: { prompt_tokens: prompt_tokens, completion_tokens: completion_tokens, @@ -1251,61 +756,3 @@ export const BedrockMistralChatCompleteResponseTransform: ( return generateInvalidProviderResponseError(response, BEDROCK); }; - -export const BedrockMistralChatCompleteStreamChunkTransform: ( - response: string, - fallbackId: string -) => string | string[] = (responseChunk, fallbackId) => { - let chunk = responseChunk.trim(); - chunk = chunk.replace(/^data: /, ''); - chunk = chunk.trim(); - const parsedChunk: BedrocMistralStreamChunk = JSON.parse(chunk); - - // discard the last cohere chunk as it sends the whole response combined. - if (parsedChunk.outputs[0].stop_reason) { - return [ - `data: ${JSON.stringify({ - id: fallbackId, - object: 'chat.completion.chunk', - created: Math.floor(Date.now() / 1000), - model: '', - provider: BEDROCK, - choices: [ - { - index: 0, - delta: {}, - finish_reason: parsedChunk.outputs[0].stop_reason, - }, - ], - usage: { - prompt_tokens: - parsedChunk['amazon-bedrock-invocationMetrics'].inputTokenCount, - completion_tokens: - parsedChunk['amazon-bedrock-invocationMetrics'].outputTokenCount, - total_tokens: - parsedChunk['amazon-bedrock-invocationMetrics'].inputTokenCount + - parsedChunk['amazon-bedrock-invocationMetrics'].outputTokenCount, - }, - })}\n\n`, - `data: [DONE]\n\n`, - ]; - } - - return `data: ${JSON.stringify({ - id: fallbackId, - object: 'chat.completion.chunk', - created: Math.floor(Date.now() / 1000), - model: '', - provider: BEDROCK, - choices: [ - { - index: 0, - delta: { - role: 'assistant', - content: parsedChunk.outputs[0].text, - }, - finish_reason: null, - }, - ], - })}\n\n`; -}; diff --git 
a/src/providers/bedrock/constants.ts b/src/providers/bedrock/constants.ts index 95b157acc..d90bd82e7 100644 --- a/src/providers/bedrock/constants.ts +++ b/src/providers/bedrock/constants.ts @@ -39,3 +39,10 @@ export const BEDROCK_STABILITY_V1_MODELS = [ 'stable-diffusion-xl-v0', 'stable-diffusion-xl-v1', ]; + +export const bedrockInvokeModels = [ + 'cohere.command-light-text-v14', + 'cohere.command-text-v14', + 'ai21.j2-mid-v1', + 'ai21.j2-ultra-v1', +]; diff --git a/src/providers/bedrock/index.ts b/src/providers/bedrock/index.ts index cb68ac707..3a5c6adc3 100644 --- a/src/providers/bedrock/index.ts +++ b/src/providers/bedrock/index.ts @@ -4,24 +4,14 @@ import { Params } from '../../types/requestBody'; import { ProviderConfigs } from '../types'; import BedrockAPIConfig from './api'; import { - BedrockAI21ChatCompleteConfig, - BedrockAI21ChatCompleteResponseTransform, - BedrockAnthropicChatCompleteConfig, - BedrockAnthropicChatCompleteResponseTransform, - BedrockAnthropicChatCompleteStreamChunkTransform, + BedrockConverseChatCompleteConfig, + BedrockChatCompleteStreamChunkTransform, + BedrockChatCompleteResponseTransform, BedrockCohereChatCompleteConfig, - BedrockCohereChatCompleteResponseTransform, BedrockCohereChatCompleteStreamChunkTransform, - BedrockLlamaChatCompleteResponseTransform, - BedrockLlamaChatCompleteStreamChunkTransform, - BedrockTitanChatCompleteResponseTransform, - BedrockTitanChatCompleteStreamChunkTransform, - BedrockTitanChatompleteConfig, - BedrockMistralChatCompleteConfig, - BedrockMistralChatCompleteResponseTransform, - BedrockMistralChatCompleteStreamChunkTransform, - BedrockLlama3ChatCompleteConfig, - BedrockLlama2ChatCompleteConfig, + BedrockCohereChatCompleteResponseTransform, + BedrockAI21ChatCompleteConfig, + BedrockAI21ChatCompleteResponseTransform, } from './chatComplete'; import { BedrockAI21CompleteConfig, @@ -62,95 +52,87 @@ const BedrockConfig: ProviderConfigs = { if (!params.model) { throw new GatewayError('Bedrock model not 
found'); } - - // To remove the region in case its a cross-region inference profile ID - // https://docs.aws.amazon.com/bedrock/latest/userguide/cross-region-inference-support.html const providerModel = params.model.replace(/^(us\.|eu\.)/, ''); const providerModelArray = providerModel.split('.'); const provider = providerModelArray[0]; const model = providerModelArray.slice(1).join('.'); + let config: ProviderConfigs = {}; switch (provider) { case ANTHROPIC: - return { + config = { complete: BedrockAnthropicCompleteConfig, - chatComplete: BedrockAnthropicChatCompleteConfig, api: BedrockAPIConfig, responseTransforms: { 'stream-complete': BedrockAnthropicCompleteStreamChunkTransform, complete: BedrockAnthropicCompleteResponseTransform, - 'stream-chatComplete': - BedrockAnthropicChatCompleteStreamChunkTransform, - chatComplete: BedrockAnthropicChatCompleteResponseTransform, }, }; + break; case COHERE: - return { + config = { complete: BedrockCohereCompleteConfig, - chatComplete: BedrockCohereChatCompleteConfig, embed: BedrockCohereEmbedConfig, api: BedrockAPIConfig, responseTransforms: { 'stream-complete': BedrockCohereCompleteStreamChunkTransform, complete: BedrockCohereCompleteResponseTransform, - 'stream-chatComplete': - BedrockCohereChatCompleteStreamChunkTransform, - chatComplete: BedrockCohereChatCompleteResponseTransform, embed: BedrockCohereEmbedResponseTransform, }, }; + if (['command-text-v14', 'command-light-text-v14'].includes(model)) { + config.chatComplete = BedrockCohereChatCompleteConfig; + config.responseTransforms['stream-chatComplete'] = + BedrockCohereChatCompleteStreamChunkTransform; + config.responseTransforms.chatComplete = + BedrockCohereChatCompleteResponseTransform; + } + break; case 'meta': - const chatCompleteConfig = - model?.search('llama3') === -1 - ? 
BedrockLlama2ChatCompleteConfig - : BedrockLlama3ChatCompleteConfig; - return { + config = { complete: BedrockLLamaCompleteConfig, - chatComplete: chatCompleteConfig, api: BedrockAPIConfig, responseTransforms: { 'stream-complete': BedrockLlamaCompleteStreamChunkTransform, complete: BedrockLlamaCompleteResponseTransform, - 'stream-chatComplete': BedrockLlamaChatCompleteStreamChunkTransform, - chatComplete: BedrockLlamaChatCompleteResponseTransform, }, }; + break; case 'mistral': - return { + config = { complete: BedrockMistralCompleteConfig, - chatComplete: BedrockMistralChatCompleteConfig, api: BedrockAPIConfig, responseTransforms: { 'stream-complete': BedrockMistralCompleteStreamChunkTransform, complete: BedrockMistralCompleteResponseTransform, - 'stream-chatComplete': - BedrockMistralChatCompleteStreamChunkTransform, - chatComplete: BedrockMistralChatCompleteResponseTransform, }, }; + break; case 'amazon': - return { + config = { complete: BedrockTitanCompleteConfig, - chatComplete: BedrockTitanChatompleteConfig, embed: BedrockTitanEmbedConfig, api: BedrockAPIConfig, responseTransforms: { 'stream-complete': BedrockTitanCompleteStreamChunkTransform, complete: BedrockTitanCompleteResponseTransform, - 'stream-chatComplete': BedrockTitanChatCompleteStreamChunkTransform, - chatComplete: BedrockTitanChatCompleteResponseTransform, embed: BedrockTitanEmbedResponseTransform, }, }; + break; case AI21: - return { + config = { complete: BedrockAI21CompleteConfig, - chatComplete: BedrockAI21ChatCompleteConfig, api: BedrockAPIConfig, responseTransforms: { complete: BedrockAI21CompleteResponseTransform, - chatComplete: BedrockAI21ChatCompleteResponseTransform, }, }; + if (['j2-mid-v1', 'j2-ultra-v1'].includes(model)) { + config.chatComplete = BedrockAI21ChatCompleteConfig; + config.responseTransforms.chatComplete = + BedrockAI21ChatCompleteResponseTransform; + } + break; case 'stability': if (model && BEDROCK_STABILITY_V1_MODELS.includes(model)) { return { @@ -168,9 +150,20 @@ 
const BedrockConfig: ProviderConfigs = { imageGenerate: BedrockStabilityAIImageGenerateV2ResponseTransform, }, }; - default: - throw new GatewayError('Invalid bedrock provider'); + break; + } + if (!config.chatComplete) { + config.chatComplete = BedrockConverseChatCompleteConfig; + } + if (!config.responseTransforms['stream-chatComplete']) { + config.responseTransforms['stream-chatComplete'] = + BedrockChatCompleteStreamChunkTransform; + } + if (!config.responseTransforms.chatComplete) { + config.responseTransforms.chatComplete = + BedrockChatCompleteResponseTransform; } + return config; }, }; diff --git a/src/providers/bedrock/utils.ts b/src/providers/bedrock/utils.ts index 7406e87e9..ede1d6fc1 100644 --- a/src/providers/bedrock/utils.ts +++ b/src/providers/bedrock/utils.ts @@ -1,11 +1,6 @@ import { SignatureV4 } from '@smithy/signature-v4'; import { Sha256 } from '@aws-crypto/sha256-js'; -import { ContentType, Message, MESSAGE_ROLES } from '../../types/requestBody'; -import { - LLAMA_2_SPECIAL_TOKENS, - LLAMA_3_SPECIAL_TOKENS, - MISTRAL_CONTROL_TOKENS, -} from './constants'; +import { BedrockChatCompletionsParams } from './chatComplete'; export const generateAWSHeaders = async ( body: Record, @@ -45,99 +40,55 @@ export const generateAWSHeaders = async ( return signed.headers; }; -/* - Helper function to use inside reduce to convert ContentType array to string -*/ -const convertContentTypesToString = (acc: string, curr: ContentType) => { - if (curr.type !== 'text') return acc; - acc += curr.text + '\n'; - return acc; -}; - -/* - Handle messages of both string and ContentType array -*/ -const getMessageContent = (message: Message) => { - if (message === undefined) return ''; - if (typeof message.content === 'object') { - return message.content.reduce(convertContentTypesToString, ''); +export const transformInferenceConfig = ( + params: BedrockChatCompletionsParams +) => { + const inferenceConfig: Record = {}; + if (params['max_tokens'] || 
params['max_completion_tokens']) { + inferenceConfig['maxTokens'] = + params['max_tokens'] || params['max_completion_tokens']; } - return message.content || ''; -}; - -/* - This function transforms the messages for the LLama 3.1 prompt. - It adds the special tokens to the beginning and end of the prompt. - refer: https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_1 - NOTE: Portkey does not restrict messages to alternate user and assistant roles, this is to support more flexible use cases. -*/ -export const transformMessagesForLLama3Prompt = (messages: Message[]) => { - let prompt: string = ''; - prompt += LLAMA_3_SPECIAL_TOKENS.PROMPT_START + '\n'; - messages.forEach((msg, index) => { - prompt += - LLAMA_3_SPECIAL_TOKENS.ROLE_START + - msg.role + - LLAMA_3_SPECIAL_TOKENS.ROLE_END + - '\n'; - prompt += getMessageContent(msg) + LLAMA_3_SPECIAL_TOKENS.END_OF_TURN; - }); - prompt += - LLAMA_3_SPECIAL_TOKENS.ROLE_START + - MESSAGE_ROLES.ASSISTANT + - LLAMA_3_SPECIAL_TOKENS.ROLE_END + - '\n'; - return prompt; -}; - -/* - This function transforms the messages for the LLama 2 prompt. - It combines the system message with the first user message, - and then attaches the message pairs. - Finally, it adds the last message to the prompt. 
- refer: https://github.com/meta-llama/llama/blob/main/llama/generation.py#L284-L395 -*/ -export const transformMessagesForLLama2Prompt = (messages: Message[]) => { - let finalPrompt: string = ''; - // combine system message with first user message - if (messages.length > 0 && messages[0].role === MESSAGE_ROLES.SYSTEM) { - messages[0].content = - LLAMA_2_SPECIAL_TOKENS.SYSTEM_MESSAGE_START + - getMessageContent(messages[0]) + - LLAMA_2_SPECIAL_TOKENS.SYSTEM_MESSAGE_END + - getMessageContent(messages[1]); + if (params['stop']) { + inferenceConfig['stopSequences'] = params['stop']; } - messages = [messages[0], ...messages.slice(2)]; - // attach message pairs - for (let i = 1; i < messages.length; i += 2) { - let prompt = getMessageContent(messages[i - 1]); - let answer = getMessageContent(messages[i]); - finalPrompt += `${LLAMA_2_SPECIAL_TOKENS.BEGINNING_OF_SENTENCE}${LLAMA_2_SPECIAL_TOKENS.CONVERSATION_TURN_START} ${prompt} ${LLAMA_2_SPECIAL_TOKENS.CONVERSATION_TURN_END} ${answer} ${LLAMA_2_SPECIAL_TOKENS.END_OF_SENTENCE}`; + if (params['temperature']) { + inferenceConfig['temperature'] = params['temperature']; } - if (messages.length % 2 === 1) { - finalPrompt += `${LLAMA_2_SPECIAL_TOKENS.BEGINNING_OF_SENTENCE}${LLAMA_2_SPECIAL_TOKENS.CONVERSATION_TURN_START} ${getMessageContent(messages[messages.length - 1])} ${LLAMA_2_SPECIAL_TOKENS.CONVERSATION_TURN_END}`; + if (params['top_p']) { + inferenceConfig['topP'] = params['top_p']; } - return finalPrompt; + return inferenceConfig; }; -/* -refer: https://docs.mistral.ai/guides/tokenization/ -refer: https://github.com/chujiezheng/chat_templates/blob/main/chat_templates/mistral-instruct.jinja -*/ -export const transformMessagesForMistralPrompt = (messages: Message[]) => { - let finalPrompt: string = `${MISTRAL_CONTROL_TOKENS.BEGINNING_OF_SENTENCE}`; - // Mistral does not support system messages. 
(ref: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3/discussions/14) - if (messages.length > 0 && messages[0].role === MESSAGE_ROLES.SYSTEM) { - messages[0].content = - getMessageContent(messages[0]) + '\n' + getMessageContent(messages[1]); - messages[0].role = MESSAGE_ROLES.USER; +export const transformAdditionalModelRequestFields = ( + params: BedrockChatCompletionsParams +) => { + const additionalModelRequestFields: Record = + params.additionalModelRequestFields || {}; + if (params['top_k']) { + additionalModelRequestFields['topK'] = params['top_k']; + } + // Backward compatibility + if (params['anthropic_version']) { + additionalModelRequestFields['anthropic_version'] = + params['anthropic_version']; + } + if (params['frequency_penalty']) { + additionalModelRequestFields['frequencyPenalty'] = + params['frequency_penalty']; + } + if (params['presence_penalty']) { + additionalModelRequestFields['presencePenalty'] = + params['presence_penalty']; + } + if (params['logit_bias']) { + additionalModelRequestFields['logitBias'] = params['logit_bias']; + } + if (params['n']) { + additionalModelRequestFields['n'] = params['n']; } - for (const message of messages) { - if (message.role === MESSAGE_ROLES.USER) { - finalPrompt += `${MISTRAL_CONTROL_TOKENS.CONVERSATION_TURN_START} ${message.content} ${MISTRAL_CONTROL_TOKENS.CONVERSATION_TURN_END}`; - } else { - finalPrompt += ` ${message.content} ${MISTRAL_CONTROL_TOKENS.END_OF_SENTENCE}`; - } + if (params['countPenalty']) { + additionalModelRequestFields['countPenalty'] = params['countPenalty']; } - return finalPrompt; + return additionalModelRequestFields; }; diff --git a/src/providers/google-vertex-ai/utils.ts b/src/providers/google-vertex-ai/utils.ts index 1eec76ba6..adc897228 100644 --- a/src/providers/google-vertex-ai/utils.ts +++ b/src/providers/google-vertex-ai/utils.ts @@ -1,6 +1,6 @@ import { GoogleErrorResponse } from './types'; import { generateErrorResponse } from '../utils'; -import { GOOGLE_VERTEX_AI 
} from '../../globals'; +import { fileExtensionMimeTypeMap, GOOGLE_VERTEX_AI } from '../../globals'; import { ErrorResponse } from '../types'; /** @@ -139,27 +139,6 @@ export const getModelAndProvider = (modelString: string) => { return { provider, model }; }; -const fileExtensionMimeTypeMap = { - mp4: 'video/mp4', - jpeg: 'image/jpeg', - jpg: 'image/jpeg', - png: 'image/png', - bmp: 'image/bmp', - tiff: 'image/tiff', - webp: 'image/webp', - pdf: 'application/pdf', - mp3: 'audio/mp3', - wav: 'audio/wav', - txt: 'text/plain', - mov: 'video/mov', - mpeg: 'video/mpeg', - mpg: 'video/mpg', - avi: 'video/avi', - wmv: 'video/wmv', - mpegps: 'video/mpegps', - flv: 'video/flv', -}; - export const getMimeType = (url: string): string | undefined => { const urlParts = url.split('.'); const extension = urlParts[ diff --git a/src/providers/types.ts b/src/providers/types.ts index 3256658bd..552abe38f 100644 --- a/src/providers/types.ts +++ b/src/providers/types.ts @@ -139,6 +139,7 @@ export interface ChatChoice { */ export interface ChatCompletionResponse extends CResponse { choices: ChatChoice[]; + provider?: string; } /** From 33f1bfaec4f65fe02854bccfac502784c519c005 Mon Sep 17 00:00:00 2001 From: Narendranath Gogineni Date: Tue, 22 Oct 2024 21:19:15 +0530 Subject: [PATCH 2/7] add comments --- src/providers/bedrock/index.ts | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/providers/bedrock/index.ts b/src/providers/bedrock/index.ts index 3a5c6adc3..930230bb7 100644 --- a/src/providers/bedrock/index.ts +++ b/src/providers/bedrock/index.ts @@ -52,6 +52,9 @@ const BedrockConfig: ProviderConfigs = { if (!params.model) { throw new GatewayError('Bedrock model not found'); } + + // To remove the region in case its a cross-region inference profile ID + // https://docs.aws.amazon.com/bedrock/latest/userguide/cross-region-inference-support.html const providerModel = params.model.replace(/^(us\.|eu\.)/, ''); const providerModelArray = providerModel.split('.'); const provider = 
providerModelArray[0]; From 913a414f3e3b09c230f05ec60f2193b42f8f1169 Mon Sep 17 00:00:00 2001 From: Narendranath Gogineni Date: Tue, 22 Oct 2024 21:30:20 +0530 Subject: [PATCH 3/7] add support for max_completion_tokens --- src/providers/bedrock/chatComplete.ts | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/providers/bedrock/chatComplete.ts b/src/providers/bedrock/chatComplete.ts index 16038d411..78eccd98c 100644 --- a/src/providers/bedrock/chatComplete.ts +++ b/src/providers/bedrock/chatComplete.ts @@ -202,6 +202,11 @@ export const BedrockConverseChatCompleteConfig: ProviderConfig = { transform: (params: BedrockChatCompletionsParams) => transformInferenceConfig(params), }, + max_completion_tokens: { + param: 'inferenceConfig', + transform: (params: BedrockChatCompletionsParams) => + transformInferenceConfig(params), + }, stop: { param: 'inferenceConfig', transform: (params: BedrockChatCompletionsParams) => From d5e797e72683753bc0f2a68832ec8dae63996316 Mon Sep 17 00:00:00 2001 From: Narendranath Gogineni Date: Tue, 22 Oct 2024 23:00:55 +0530 Subject: [PATCH 4/7] Restrict known additional model fields to the specific providers --- src/providers/bedrock/chatComplete.ts | 163 ++++++++++++++++++++------ src/providers/bedrock/index.ts | 6 + src/providers/bedrock/utils.ts | 76 ++++++++++-- 3 files changed, 203 insertions(+), 42 deletions(-) diff --git a/src/providers/bedrock/chatComplete.ts b/src/providers/bedrock/chatComplete.ts index 78eccd98c..7896c0178 100644 --- a/src/providers/bedrock/chatComplete.ts +++ b/src/providers/bedrock/chatComplete.ts @@ -17,6 +17,9 @@ import { import { BedrockErrorResponse } from './embed'; import { transformAdditionalModelRequestFields, + transformAI21AdditionalModelRequestFields, + transformAnthropicAdditionalModelRequestFields, + transformCohereAdditionalModelRequestFields, transformInferenceConfig, } from './utils'; @@ -32,6 +35,29 @@ export interface BedrockChatCompletionsParams extends Params { countPenalty?: number; } 
+export interface BedrockConverseAnthropicChatCompletionsParams + extends BedrockChatCompletionsParams { + anthropic_version?: string; + user?: string; +} + +export interface BedrockConverseCohereChatCompletionsParams + extends BedrockChatCompletionsParams { + frequency_penalty?: number; + presence_penalty?: number; + logit_bias?: Record; + n?: number; +} + +export interface BedrockConverseAI21ChatCompletionsParams + extends BedrockChatCompletionsParams { + frequency_penalty?: number; + presence_penalty?: number; + frequencyPenalty?: number; + presencePenalty?: number; + countPenalty?: number; +} + const getMessageTextContentArray = (message: Message): { text: string }[] => { if (message.content && typeof message.content === 'object') { return message.content @@ -232,41 +258,6 @@ export const BedrockConverseChatCompleteConfig: ProviderConfig = { transform: (params: BedrockChatCompletionsParams) => transformAdditionalModelRequestFields(params), }, - anthropic_version: { - param: 'additionalModelRequestFields', - transform: (params: BedrockChatCompletionsParams) => - transformAdditionalModelRequestFields(params), - }, - frequency_penalty: { - param: 'additionalModelRequestFields', - transform: (params: BedrockChatCompletionsParams) => - transformAdditionalModelRequestFields(params), - }, - presence_penalty: { - param: 'additionalModelRequestFields', - transform: (params: BedrockChatCompletionsParams) => - transformAdditionalModelRequestFields(params), - }, - logit_bias: { - param: 'additionalModelRequestFields', - transform: (params: BedrockChatCompletionsParams) => - transformAdditionalModelRequestFields(params), - }, - n: { - param: 'additionalModelRequestFields', - transform: (params: BedrockChatCompletionsParams) => - transformAdditionalModelRequestFields(params), - }, - stream: { - param: 'additionalModelRequestFields', - transform: (params: BedrockChatCompletionsParams) => - transformAdditionalModelRequestFields(params), - }, - countPenalty: { - param: 
'additionalModelRequestFields', - transform: (params: BedrockChatCompletionsParams) => - transformAdditionalModelRequestFields(params), - }, }; interface BedrockChatCompletionResponse { @@ -460,6 +451,108 @@ export const BedrockChatCompleteStreamChunkTransform: ( })}\n\n`; }; +export const BedrockConverseAnthropicChatCompleteConfig: ProviderConfig = { + ...BedrockConverseChatCompleteConfig, + additionalModelRequestFields: { + param: 'additionalModelRequestFields', + transform: (params: BedrockChatCompletionsParams) => + transformAnthropicAdditionalModelRequestFields(params), + }, + top_k: { + param: 'additionalModelRequestFields', + transform: (params: BedrockChatCompletionsParams) => + transformAnthropicAdditionalModelRequestFields(params), + }, + anthropic_version: { + param: 'additionalModelRequestFields', + transform: (params: BedrockChatCompletionsParams) => + transformAnthropicAdditionalModelRequestFields(params), + }, + user: { + param: 'user', + transform: (params: BedrockChatCompletionsParams) => + transformAnthropicAdditionalModelRequestFields(params), + }, +}; + +export const BedrockConverseCohereChatCompleteConfig: ProviderConfig = { + ...BedrockConverseChatCompleteConfig, + additionalModelRequestFields: { + param: 'additionalModelRequestFields', + transform: (params: BedrockChatCompletionsParams) => + transformCohereAdditionalModelRequestFields(params), + }, + top_k: { + param: 'additionalModelRequestFields', + transform: (params: BedrockChatCompletionsParams) => + transformCohereAdditionalModelRequestFields(params), + }, + frequency_penalty: { + param: 'additionalModelRequestFields', + transform: (params: BedrockChatCompletionsParams) => + transformCohereAdditionalModelRequestFields(params), + }, + presence_penalty: { + param: 'additionalModelRequestFields', + transform: (params: BedrockChatCompletionsParams) => + transformCohereAdditionalModelRequestFields(params), + }, + logit_bias: { + param: 'additionalModelRequestFields', + transform: (params: 
BedrockChatCompletionsParams) => + transformCohereAdditionalModelRequestFields(params), + }, + n: { + param: 'additionalModelRequestFields', + transform: (params: BedrockChatCompletionsParams) => + transformCohereAdditionalModelRequestFields(params), + }, + stream: { + param: 'additionalModelRequestFields', + transform: (params: BedrockChatCompletionsParams) => + transformCohereAdditionalModelRequestFields(params), + }, +}; + +export const BedrockConverseAI21ChatCompleteConfig: ProviderConfig = { + ...BedrockConverseChatCompleteConfig, + additionalModelRequestFields: { + param: 'additionalModelRequestFields', + transform: (params: BedrockChatCompletionsParams) => + transformAI21AdditionalModelRequestFields(params), + }, + top_k: { + param: 'additionalModelRequestFields', + transform: (params: BedrockChatCompletionsParams) => + transformAI21AdditionalModelRequestFields(params), + }, + frequency_penalty: { + param: 'additionalModelRequestFields', + transform: (params: BedrockChatCompletionsParams) => + transformAI21AdditionalModelRequestFields(params), + }, + presence_penalty: { + param: 'additionalModelRequestFields', + transform: (params: BedrockChatCompletionsParams) => + transformAI21AdditionalModelRequestFields(params), + }, + frequencyPenalty: { + param: 'additionalModelRequestFields', + transform: (params: BedrockChatCompletionsParams) => + transformAI21AdditionalModelRequestFields(params), + }, + presencePenalty: { + param: 'additionalModelRequestFields', + transform: (params: BedrockChatCompletionsParams) => + transformAI21AdditionalModelRequestFields(params), + }, + countPenalty: { + param: 'additionalModelRequestFields', + transform: (params: BedrockChatCompletionsParams) => + transformAI21AdditionalModelRequestFields(params), + }, +}; + export const BedrockCohereChatCompleteConfig: ProviderConfig = { messages: { param: 'prompt', diff --git a/src/providers/bedrock/index.ts b/src/providers/bedrock/index.ts index 930230bb7..523d010e0 100644 --- 
a/src/providers/bedrock/index.ts +++ b/src/providers/bedrock/index.ts @@ -12,6 +12,9 @@ import { BedrockCohereChatCompleteResponseTransform, BedrockAI21ChatCompleteConfig, BedrockAI21ChatCompleteResponseTransform, + BedrockConverseAnthropicChatCompleteConfig, + BedrockConverseCohereChatCompleteConfig, + BedrockConverseAI21ChatCompleteConfig, } from './chatComplete'; import { BedrockAI21CompleteConfig, @@ -64,6 +67,7 @@ const BedrockConfig: ProviderConfigs = { case ANTHROPIC: config = { complete: BedrockAnthropicCompleteConfig, + chatComplete: BedrockConverseAnthropicChatCompleteConfig, api: BedrockAPIConfig, responseTransforms: { 'stream-complete': BedrockAnthropicCompleteStreamChunkTransform, @@ -74,6 +78,7 @@ const BedrockConfig: ProviderConfigs = { case COHERE: config = { complete: BedrockCohereCompleteConfig, + chatComplete: BedrockConverseCohereChatCompleteConfig, embed: BedrockCohereEmbedConfig, api: BedrockAPIConfig, responseTransforms: { @@ -126,6 +131,7 @@ const BedrockConfig: ProviderConfigs = { config = { complete: BedrockAI21CompleteConfig, api: BedrockAPIConfig, + chatComplete: BedrockConverseAI21ChatCompleteConfig, responseTransforms: { complete: BedrockAI21CompleteResponseTransform, }, diff --git a/src/providers/bedrock/utils.ts b/src/providers/bedrock/utils.ts index ede1d6fc1..40d2606d9 100644 --- a/src/providers/bedrock/utils.ts +++ b/src/providers/bedrock/utils.ts @@ -1,6 +1,11 @@ import { SignatureV4 } from '@smithy/signature-v4'; import { Sha256 } from '@aws-crypto/sha256-js'; -import { BedrockChatCompletionsParams } from './chatComplete'; +import { + BedrockConverseAI21ChatCompletionsParams, + BedrockConverseAnthropicChatCompletionsParams, + BedrockChatCompletionsParams, + BedrockConverseCohereChatCompletionsParams, +} from './chatComplete'; export const generateAWSHeaders = async ( body: Record, @@ -66,26 +71,83 @@ export const transformAdditionalModelRequestFields = ( const additionalModelRequestFields: Record = 
params.additionalModelRequestFields || {}; if (params['top_k']) { - additionalModelRequestFields['topK'] = params['top_k']; + additionalModelRequestFields['top_k'] = params['top_k']; + } + return additionalModelRequestFields; +}; + +export const transformAnthropicAdditionalModelRequestFields = ( + params: BedrockConverseAnthropicChatCompletionsParams +) => { + const additionalModelRequestFields: Record = + params.additionalModelRequestFields || {}; + if (params['top_k']) { + additionalModelRequestFields['top_k'] = params['top_k']; } - // Backward compatibility if (params['anthropic_version']) { additionalModelRequestFields['anthropic_version'] = params['anthropic_version']; } + if (params['user']) { + additionalModelRequestFields['metadata'] = { + user_id: params['user'], + }; + } + return additionalModelRequestFields; +}; + +export const transformCohereAdditionalModelRequestFields = ( + params: BedrockConverseCohereChatCompletionsParams +) => { + const additionalModelRequestFields: Record = + params.additionalModelRequestFields || {}; + if (params['top_k']) { + additionalModelRequestFields['top_k'] = params['top_k']; + } + if (params['n']) { + additionalModelRequestFields['n'] = params['n']; + } if (params['frequency_penalty']) { - additionalModelRequestFields['frequencyPenalty'] = + additionalModelRequestFields['frequency_penalty'] = params['frequency_penalty']; } if (params['presence_penalty']) { - additionalModelRequestFields['presencePenalty'] = + additionalModelRequestFields['presence_penalty'] = params['presence_penalty']; } if (params['logit_bias']) { additionalModelRequestFields['logitBias'] = params['logit_bias']; } - if (params['n']) { - additionalModelRequestFields['n'] = params['n']; + if (params['stream']) { + additionalModelRequestFields['stream'] = params['stream']; + } + return additionalModelRequestFields; +}; + +export const transformAI21AdditionalModelRequestFields = ( + params: BedrockConverseAI21ChatCompletionsParams +) => { + const 
additionalModelRequestFields: Record = + params.additionalModelRequestFields || {}; + if (params['top_k']) { + additionalModelRequestFields['top_k'] = params['top_k']; + } + if (params['frequency_penalty']) { + additionalModelRequestFields['frequencyPenalty'] = { + scale: params['frequency_penalty'], + }; + } + if (params['presence_penalty']) { + additionalModelRequestFields['presencePenalty'] = { + scale: params['presence_penalty'], + }; + } + if (params['frequencyPenalty']) { + additionalModelRequestFields['frequencyPenalty'] = + params['frequencyPenalty']; + } + if (params['presencePenalty']) { + additionalModelRequestFields['presencePenalty'] = params['presencePenalty']; } if (params['countPenalty']) { additionalModelRequestFields['countPenalty'] = params['countPenalty']; From f13d1cfa0e811bf3bb215438bcfba2dc846ad63b Mon Sep 17 00:00:00 2001 From: Narendranath Gogineni Date: Tue, 22 Oct 2024 23:05:31 +0530 Subject: [PATCH 5/7] Fix interfaces --- src/providers/bedrock/chatComplete.ts | 36 +++++++++++++-------------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/src/providers/bedrock/chatComplete.ts b/src/providers/bedrock/chatComplete.ts index 7896c0178..65a161b30 100644 --- a/src/providers/bedrock/chatComplete.ts +++ b/src/providers/bedrock/chatComplete.ts @@ -455,22 +455,22 @@ export const BedrockConverseAnthropicChatCompleteConfig: ProviderConfig = { ...BedrockConverseChatCompleteConfig, additionalModelRequestFields: { param: 'additionalModelRequestFields', - transform: (params: BedrockChatCompletionsParams) => + transform: (params: BedrockConverseAnthropicChatCompletionsParams) => transformAnthropicAdditionalModelRequestFields(params), }, top_k: { param: 'additionalModelRequestFields', - transform: (params: BedrockChatCompletionsParams) => + transform: (params: BedrockConverseAnthropicChatCompletionsParams) => transformAnthropicAdditionalModelRequestFields(params), }, anthropic_version: { param: 'additionalModelRequestFields', - transform: 
(params: BedrockChatCompletionsParams) => + transform: (params: BedrockConverseAnthropicChatCompletionsParams) => transformAnthropicAdditionalModelRequestFields(params), }, user: { param: 'user', - transform: (params: BedrockChatCompletionsParams) => + transform: (params: BedrockConverseAnthropicChatCompletionsParams) => transformAnthropicAdditionalModelRequestFields(params), }, }; @@ -479,37 +479,37 @@ export const BedrockConverseCohereChatCompleteConfig: ProviderConfig = { ...BedrockConverseChatCompleteConfig, additionalModelRequestFields: { param: 'additionalModelRequestFields', - transform: (params: BedrockChatCompletionsParams) => + transform: (params: BedrockConverseCohereChatCompletionsParams) => transformCohereAdditionalModelRequestFields(params), }, top_k: { param: 'additionalModelRequestFields', - transform: (params: BedrockChatCompletionsParams) => + transform: (params: BedrockConverseCohereChatCompletionsParams) => transformCohereAdditionalModelRequestFields(params), }, frequency_penalty: { param: 'additionalModelRequestFields', - transform: (params: BedrockChatCompletionsParams) => + transform: (params: BedrockConverseCohereChatCompletionsParams) => transformCohereAdditionalModelRequestFields(params), }, presence_penalty: { param: 'additionalModelRequestFields', - transform: (params: BedrockChatCompletionsParams) => + transform: (params: BedrockConverseCohereChatCompletionsParams) => transformCohereAdditionalModelRequestFields(params), }, logit_bias: { param: 'additionalModelRequestFields', - transform: (params: BedrockChatCompletionsParams) => + transform: (params: BedrockConverseCohereChatCompletionsParams) => transformCohereAdditionalModelRequestFields(params), }, n: { param: 'additionalModelRequestFields', - transform: (params: BedrockChatCompletionsParams) => + transform: (params: BedrockConverseCohereChatCompletionsParams) => transformCohereAdditionalModelRequestFields(params), }, stream: { param: 'additionalModelRequestFields', - transform: 
(params: BedrockChatCompletionsParams) => + transform: (params: BedrockConverseCohereChatCompletionsParams) => transformCohereAdditionalModelRequestFields(params), }, }; @@ -518,37 +518,37 @@ export const BedrockConverseAI21ChatCompleteConfig: ProviderConfig = { ...BedrockConverseChatCompleteConfig, additionalModelRequestFields: { param: 'additionalModelRequestFields', - transform: (params: BedrockChatCompletionsParams) => + transform: (params: BedrockConverseAI21ChatCompletionsParams) => transformAI21AdditionalModelRequestFields(params), }, top_k: { param: 'additionalModelRequestFields', - transform: (params: BedrockChatCompletionsParams) => + transform: (params: BedrockConverseAI21ChatCompletionsParams) => transformAI21AdditionalModelRequestFields(params), }, frequency_penalty: { param: 'additionalModelRequestFields', - transform: (params: BedrockChatCompletionsParams) => + transform: (params: BedrockConverseAI21ChatCompletionsParams) => transformAI21AdditionalModelRequestFields(params), }, presence_penalty: { param: 'additionalModelRequestFields', - transform: (params: BedrockChatCompletionsParams) => + transform: (params: BedrockConverseAI21ChatCompletionsParams) => transformAI21AdditionalModelRequestFields(params), }, frequencyPenalty: { param: 'additionalModelRequestFields', - transform: (params: BedrockChatCompletionsParams) => + transform: (params: BedrockConverseAI21ChatCompletionsParams) => transformAI21AdditionalModelRequestFields(params), }, presencePenalty: { param: 'additionalModelRequestFields', - transform: (params: BedrockChatCompletionsParams) => + transform: (params: BedrockConverseAI21ChatCompletionsParams) => transformAI21AdditionalModelRequestFields(params), }, countPenalty: { param: 'additionalModelRequestFields', - transform: (params: BedrockChatCompletionsParams) => + transform: (params: BedrockConverseAI21ChatCompletionsParams) => transformAI21AdditionalModelRequestFields(params), }, }; From 878ad9a0e96770f14f242bebc513ac72ebf7feb0 Mon Sep 
17 00:00:00 2001 From: Narendranath Gogineni Date: Wed, 23 Oct 2024 00:38:54 +0530 Subject: [PATCH 6/7] fix key error --- src/providers/bedrock/chatComplete.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/providers/bedrock/chatComplete.ts b/src/providers/bedrock/chatComplete.ts index 65a161b30..a54719a0c 100644 --- a/src/providers/bedrock/chatComplete.ts +++ b/src/providers/bedrock/chatComplete.ts @@ -469,7 +469,7 @@ export const BedrockConverseAnthropicChatCompleteConfig: ProviderConfig = { transformAnthropicAdditionalModelRequestFields(params), }, user: { - param: 'user', + param: 'additionalModelRequestFields', transform: (params: BedrockConverseAnthropicChatCompletionsParams) => transformAnthropicAdditionalModelRequestFields(params), }, From b0e32c9754705235a8d63093c15192bccb698cea Mon Sep 17 00:00:00 2001 From: Narendranath Gogineni Date: Wed, 23 Oct 2024 15:36:39 +0530 Subject: [PATCH 7/7] make function optional --- src/providers/types.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/providers/types.ts b/src/providers/types.ts index 552abe38f..b0de165ad 100644 --- a/src/providers/types.ts +++ b/src/providers/types.ts @@ -50,7 +50,7 @@ export interface ProviderAPIConfig { gatewayRequestBody: Params; }) => string; /** A function to determine if the request body should be transformed to form data */ - transformToFormData: (args: { gatewayRequestBody: Params }) => boolean; + transformToFormData?: (args: { gatewayRequestBody: Params }) => boolean; } export type endpointStrings =