From 74e2513c9c5280cb6d992e2d76cfaa4a9ad4a817 Mon Sep 17 00:00:00 2001 From: Mahesh Date: Wed, 29 Jan 2025 18:07:36 +0530 Subject: [PATCH] chore: allow azure-ai to accept the extra-params from the headers --- src/handlers/handlerUtils.ts | 1 + src/providers/azure-ai-inference/api.ts | 10 +++++++--- src/types/requestBody.ts | 1 + 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/src/handlers/handlerUtils.ts b/src/handlers/handlerUtils.ts index 2a10ecb5c..404305384 100644 --- a/src/handlers/handlerUtils.ts +++ b/src/handlers/handlerUtils.ts @@ -914,6 +914,7 @@ export function constructConfigFromRequestHeaders( requestHeaders[`x-${POWERED_BY}-azure-deployment-type`], azureApiVersion: requestHeaders[`x-${POWERED_BY}-azure-api-version`], azureEndpointName: requestHeaders[`x-${POWERED_BY}-azure-endpoint-name`], + azureExtraParams: requestHeaders[`x-${POWERED_BY}-azure-extra-params`], }; const awsConfig = { diff --git a/src/providers/azure-ai-inference/api.ts b/src/providers/azure-ai-inference/api.ts index 5693ef11d..699d44c20 100644 --- a/src/providers/azure-ai-inference/api.ts +++ b/src/providers/azure-ai-inference/api.ts @@ -20,11 +20,15 @@ const AzureAIInferenceAPI: ProviderAPIConfig = { return `https://${azureEndpointName}.${azureRegion}.inference.ml.azure.com/score`; }, headers: ({ providerOptions }) => { - const { apiKey, azureDeploymentType, azureDeploymentName } = - providerOptions; + const { + apiKey, + azureDeploymentType, + azureDeploymentName, + azureExtraParams, + } = providerOptions; const headers: Record = { Authorization: `Bearer ${apiKey}`, - 'extra-parameters': 'ignore', + 'extra-parameters': azureExtraParams || 'pass-through', }; if (azureDeploymentType === 'managed' && azureDeploymentName) { headers['azureml-model-deployment'] = azureDeploymentName; diff --git a/src/types/requestBody.ts b/src/types/requestBody.ts index 1c6af6b7e..f4eeac8a6 100644 --- a/src/types/requestBody.ts +++ b/src/types/requestBody.ts @@ -126,6 +126,7 @@ export interface Options { azureDeploymentType?: 'managed' | 'serverless'; azureEndpointName?: string; azureApiVersion?: string; + azureExtraParams?: string; /** The parameter to determine if extra non-openai compliant fields should be returned in response */ strictOpenAiCompliance?: boolean;