Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: azure entra and managed identity integration #580

Merged
merged 10 commits into from
Nov 4, 2024
8 changes: 8 additions & 0 deletions src/handlers/handlerUtils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1012,6 +1012,10 @@ export function constructConfigFromRequestHeaders(
deploymentId: requestHeaders[`x-${POWERED_BY}-azure-deployment-id`],
apiVersion: requestHeaders[`x-${POWERED_BY}-azure-api-version`],
azureModelName: requestHeaders[`x-${POWERED_BY}-azure-model-name`],
azureEntraClientId: requestHeaders[`x-${POWERED_BY}-azure-entra-client-id`],
azureEntraClientSecret:
requestHeaders[`x-${POWERED_BY}-azure-entra-client-secret`],
azureEntraTenantId: requestHeaders[`x-${POWERED_BY}-azure-entra-tenant-id`],
};

const stabilityAiConfig = {
Expand All @@ -1030,6 +1034,10 @@ export function constructConfigFromRequestHeaders(
requestHeaders[`x-${POWERED_BY}-azure-deployment-type`],
azureApiVersion: requestHeaders[`x-${POWERED_BY}-azure-api-version`],
azureEndpointName: requestHeaders[`x-${POWERED_BY}-azure-endpoint-name`],
azureEntraClientId: requestHeaders[`x-${POWERED_BY}-azure-entra-client-id`],
azureEntraClientSecret:
requestHeaders[`x-${POWERED_BY}-azure-entra-client-secret`],
azureEntraTenantId: requestHeaders[`x-${POWERED_BY}-azure-entra-tenant-id`],
};

const bedrockConfig = {
Expand Down
22 changes: 19 additions & 3 deletions src/providers/azure-ai-inference/api.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { GITHUB } from '../../globals';
import { getAccessTokenFromEntraId } from '../azure-openai/utils';
import { ProviderAPIConfig } from '../types';

const AzureAIInferenceAPI: ProviderAPIConfig = {
Expand All @@ -19,16 +20,31 @@ const AzureAIInferenceAPI: ProviderAPIConfig = {

return `https://${azureEndpointName}.${azureRegion}.inference.ml.azure.com/score`;
},
headers: ({ providerOptions }) => {
const { apiKey, azureDeploymentType, azureDeploymentName } =
providerOptions;
headers: async ({ providerOptions }) => {
const {
apiKey,
azureDeploymentType,
azureDeploymentName,
azureEntraClientId,
azureEntraClientSecret,
azureEntraTenantId,
} = providerOptions;
const headers: Record<string, string> = {
Authorization: `Bearer ${apiKey}`,
'extra-parameters': 'ignore',
};
if (azureDeploymentType === 'managed' && azureDeploymentName) {
headers['azureml-model-deployment'] = azureDeploymentName;
}
if (azureEntraClientId && azureEntraClientSecret && azureEntraTenantId) {
const accessToken = await getAccessTokenFromEntraId(
azureEntraTenantId,
azureEntraClientId,
azureEntraClientSecret,
'https://cognitiveservices.azure.com/decision/.default'
);
headers['Authorization'] = `Bearer ${accessToken}`;
}
return headers;
},
getEndpoint: ({ providerOptions, fn }) => {
Expand Down
15 changes: 14 additions & 1 deletion src/providers/azure-openai/api.ts
Original file line number Diff line number Diff line change
@@ -1,16 +1,29 @@
import { ProviderAPIConfig } from '../types';
import { getAccessTokenFromEntraId } from './utils';

const AzureOpenAIAPIConfig: ProviderAPIConfig = {
getBaseURL: ({ providerOptions }) => {
const { resourceName, deploymentId } = providerOptions;
return `https://${resourceName}.openai.azure.com/openai/deployments/${deploymentId}`;
},
headers: ({ providerOptions, fn }) => {
headers: async ({ providerOptions, fn }) => {
const headersObj: Record<string, string> = {
'api-key': `${providerOptions.apiKey}`,
};
if (fn === 'createTranscription' || fn === 'createTranslation')
headersObj['Content-Type'] = 'multipart/form-data';
const { azureEntraClientId, azureEntraClientSecret, azureEntraTenantId } =
providerOptions;
if (azureEntraClientId && azureEntraClientSecret && azureEntraTenantId) {
const accessToken = await getAccessTokenFromEntraId(
azureEntraTenantId,
azureEntraClientId,
azureEntraClientSecret,
'https://cognitiveservices.azure.com/decision/.default'
);
headersObj['Authorization'] = `Bearer ${accessToken}`;
}

return headersObj;
},
getEndpoint: ({ providerOptions, fn }) => {
Expand Down
29 changes: 29 additions & 0 deletions src/providers/azure-openai/utils.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
export async function getAccessTokenFromEntraId(
tenantId: string,
clientId: string,
clientSecret: string,
scope = 'https://openai.azure.com/.default'
) {
const url = `https://login.microsoftonline.com/${tenantId}/oauth2/v2.0/token`;

const params = new URLSearchParams({
client_id: clientId,
client_secret: clientSecret,
scope: scope,
grant_type: 'client_credentials',
});

const response = await fetch(url, {
method: 'POST',
headers: {
'Content-Type': 'application/x-www-form-urlencoded',
},
body: params,
});

if (!response.ok) {
throw new Error(`Error fetching access token: ${response.statusText}`);
}
const data: { access_token: string } = await response.json();
return data.access_token;
}
3 changes: 3 additions & 0 deletions src/types/requestBody.ts
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ export interface Options {
apiVersion?: string;
adAuth?: string;
azureModelName?: string;
azureEntraClientId?: string;
azureEntraClientSecret?: string;
azureEntraTenantId?: string;
/** Workers AI specific */
workersAiAccountId?: string;
/** The parameter to set custom base url */
Expand Down
Loading