diff --git a/src/handlers/handlerUtils.ts b/src/handlers/handlerUtils.ts index b06e4cf37..991026a88 100644 --- a/src/handlers/handlerUtils.ts +++ b/src/handlers/handlerUtils.ts @@ -1013,6 +1013,13 @@ export function constructConfigFromRequestHeaders( resourceName: requestHeaders[`x-${POWERED_BY}-azure-resource-name`], deploymentId: requestHeaders[`x-${POWERED_BY}-azure-deployment-id`], apiVersion: requestHeaders[`x-${POWERED_BY}-azure-api-version`], + azureAuthMode: requestHeaders[`x-${POWERED_BY}-azure-auth-mode`], + azureManagedClientId: + requestHeaders[`x-${POWERED_BY}-azure-managed-client-id`], + azureEntraClientId: requestHeaders[`x-${POWERED_BY}-azure-entra-client-id`], + azureEntraClientSecret: + requestHeaders[`x-${POWERED_BY}-azure-entra-client-secret`], + azureEntraTenantId: requestHeaders[`x-${POWERED_BY}-azure-entra-tenant-id`], azureModelName: requestHeaders[`x-${POWERED_BY}-azure-model-name`], }; diff --git a/src/providers/azure-openai/api.ts b/src/providers/azure-openai/api.ts index 7ccb5a58d..7010a4346 100644 --- a/src/providers/azure-openai/api.ts +++ b/src/providers/azure-openai/api.ts @@ -1,13 +1,46 @@ import { ProviderAPIConfig } from '../types'; +import { + getAccessTokenFromEntraId, + getAzureManagedIdentityToken, +} from './utils'; const AzureOpenAIAPIConfig: ProviderAPIConfig = { getBaseURL: ({ providerOptions }) => { const { resourceName, deploymentId } = providerOptions; return `https://${resourceName}.openai.azure.com/openai/deployments/${deploymentId}`; }, - headers: ({ providerOptions, fn }) => { + headers: async ({ providerOptions, fn }) => { + const { apiKey, azureAuthMode } = providerOptions; + + if (azureAuthMode === 'entra') { + const { azureEntraTenantId, azureEntraClientId, azureEntraClientSecret } = + providerOptions; + if (azureEntraTenantId && azureEntraClientId && azureEntraClientSecret) { + const scope = 'https://cognitiveservices.azure.com/.default'; + const accessToken = await getAccessTokenFromEntraId( + azureEntraTenantId, + azureEntraClientId, + azureEntraClientSecret, + scope + ); + return { + Authorization: `Bearer ${accessToken}`, + }; + } + } + if (azureAuthMode === 'managed') { + const { azureManagedClientId } = providerOptions; + const resource = 'https://cognitiveservices.azure.com/'; + const accessToken = await getAzureManagedIdentityToken( + resource, + azureManagedClientId + ); + return { + Authorization: `Bearer ${accessToken}`, + }; + } const headersObj: Record = { - 'api-key': `${providerOptions.apiKey}`, + 'api-key': `${apiKey}`, }; if (fn === 'createTranscription' || fn === 'createTranslation') headersObj['Content-Type'] = 'multipart/form-data'; diff --git a/src/providers/azure-openai/utils.ts b/src/providers/azure-openai/utils.ts new file mode 100644 index 000000000..2082ea205 --- /dev/null +++ b/src/providers/azure-openai/utils.ts @@ -0,0 +1,60 @@ +export async function getAccessTokenFromEntraId( + tenantId: string, + clientId: string, + clientSecret: string, + scope = 'https://cognitiveservices.azure.com/.default' +) { + try { + const url = `https://login.microsoftonline.com/${tenantId}/oauth2/v2.0/token`; + const params = new URLSearchParams({ + client_id: clientId, + client_secret: clientSecret, + scope: scope, + grant_type: 'client_credentials', + }); + + const response = await fetch(url, { + method: 'POST', + headers: { + 'Content-Type': 'application/x-www-form-urlencoded', + }, + body: params, + }); + + if (!response.ok) { + const errorMessage = await response.text(); + console.log({ message: `Error from Entra ${errorMessage}` }); + return undefined; + } + const data: { access_token: string } = await response.json(); + return data.access_token; + } catch (error) { + console.log(error); + } +} + +export async function getAzureManagedIdentityToken( + resource: string, + clientId?: string +) { + try { + const response = await fetch( + `http://169.254.169.254/metadata/identity/oauth2/token?api-version=2018-02-01&resource=${encodeURIComponent(resource)}${clientId ? `&client_id=${encodeURIComponent(clientId)}` : ''}`, + { + method: 'GET', + headers: { + Metadata: 'true', + }, + } + ); + if (!response.ok) { + const errorMessage = await response.text(); + console.log({ message: `Error from Managed ${errorMessage}` }); + return undefined; + } + const data: { access_token: string } = await response.json(); + return data.access_token; + } catch (error) { + console.log({ error }); + } +} diff --git a/src/types/requestBody.ts b/src/types/requestBody.ts index 9c3efb030..5c3001024 100644 --- a/src/types/requestBody.ts +++ b/src/types/requestBody.ts @@ -60,6 +60,11 @@ export interface Options { apiVersion?: string; adAuth?: string; azureModelName?: string; + azureAuthMode?: string; // can be entra or managed + azureManagedClientId?: string; + azureEntraClientId?: string; + azureEntraClientSecret?: string; + azureEntraTenantId?: string; /** Workers AI specific */ workersAiAccountId?: string; /** The parameter to set custom base url */ @@ -143,6 +148,12 @@ export interface Targets { deploymentId?: string; apiVersion?: string; adAuth?: string; + azureAuthMode?: string; + azureManagedClientId?: string; + azureEntraClientId?: string; + azureEntraClientSecret?: string; + azureEntraTenantId?: string; + azureModelName?: string; /** provider option index picked based on weight in loadbalance mode */ index?: number; cache?: CacheSettings | string; @@ -356,6 +367,11 @@ export interface ShortConfig { azureModelName?: string; workersAiAccountId?: string; apiVersion?: string; + azureAuthMode?: string; + azureManagedClientId?: string; + azureEntraClientId?: string; + azureEntraClientSecret?: string; + azureEntraTenantId?: string; customHost?: string; // Google Vertex AI specific vertexRegion?: string;