diff --git a/src/providers/bedrock/chatComplete.ts b/src/providers/bedrock/chatComplete.ts index 5605ca5a0..bb8d0bc31 100644 --- a/src/providers/bedrock/chatComplete.ts +++ b/src/providers/bedrock/chatComplete.ts @@ -11,6 +11,16 @@ import { } from '../utils'; import { BedrockErrorResponse } from './embed'; +interface BedrockChatCompletionsParams extends Params { + additionalModelRequestFields?: Record; + additionalModelResponseFieldPaths?: string[]; + guardrailConfig?: { + guardrailIdentifier: string; + guardrailVersion: string; + trace?: string; + }; +} + const getMessageTextContentArray = (message: Message): { text: string }[] => { if (message.content && typeof message.content === 'object') { return message.content @@ -75,7 +85,7 @@ export const BedrockConverseChatCompleteConfig: ProviderConfig = { { param: 'messages', required: true, - transform: (params: Params) => { + transform: (params: BedrockChatCompletionsParams) => { if (!params.messages) return []; return params.messages.map((msg) => { if (msg.role === 'system') return; @@ -89,7 +99,7 @@ export const BedrockConverseChatCompleteConfig: ProviderConfig = { { param: 'system', required: false, - transform: (params: Params) => { + transform: (params: BedrockChatCompletionsParams) => { if (!params.messages) return; const systemMessages = params.messages.reduce( (acc: { text: string }[], msg) => { @@ -105,7 +115,7 @@ export const BedrockConverseChatCompleteConfig: ProviderConfig = { }, { param: 'inferenceConfig', - transform: (params: Params) => { + transform: (params: BedrockChatCompletionsParams) => { return { maxTokens: params.max_tokens || params.max_completion_tokens, stopSequences: @@ -117,16 +127,17 @@ export const BedrockConverseChatCompleteConfig: ProviderConfig = { }, { param: 'additionalModelRequestFields', - transform: (params: Params) => { + transform: (params: BedrockChatCompletionsParams) => { return { topK: params.top_k, + ...params.additionalModelRequestFields, }; }, }, ], tools: { param: 'toolConfig', - transform: (params: Params) => { + transform: (params: BedrockChatCompletionsParams) => { const toolConfig = { tools: params.tools?.map((tool) => { if (!tool.function) return; @@ -139,28 +150,37 @@ export const BedrockConverseChatCompleteConfig: ProviderConfig = { }; }), }; + let toolChoice = undefined; if (params.tool_choice) { if (typeof params.tool_choice === 'object') { - return { + toolChoice = { tool: { name: params.tool_choice.function.name, }, }; - } - switch (params.tool_choice) { - case 'required': - return { + } else if (typeof params.tool_choice === 'string') { + if (params.tool_choice === 'required') { + toolChoice = { any: {}, }; - case 'auto': - return { + } else if (params.tool_choice === 'auto') { + toolChoice = { auto: {}, }; + } } } - return toolConfig; + return { ...toolConfig, toolChoice }; }, }, + guardrailConfig: { + param: 'guardrailConfig', + required: false, + }, + additionalModelResponseFieldPaths: { + param: 'additionalModelResponseFieldPaths', + required: false, + }, }; interface BedrockChatCompletionResponse {