Skip to content

Commit

Permalink
bedrock create batch
Browse files Browse the repository at this point in the history
  • Loading branch information
narengogi committed Dec 9, 2024
1 parent 1a0dd19 commit a39994b
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 1 deletion.
6 changes: 5 additions & 1 deletion src/providers/bedrock/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ const BedrockAPIConfig: ProviderAPIConfig = {
providerOptions.awsSecretAccessKey
);
}
return `https://bedrock-runtime.${providerOptions.awsRegion || 'us-east-1'}.amazonaws.com`;
const isAWSControlPlaneEndpoint = fn === 'createBatch';
return `https://${isAWSControlPlaneEndpoint ? 'bedrock' : 'bedrock-runtime'}.${providerOptions.awsRegion || 'us-east-1'}.amazonaws.com`;
},
headers: async ({
c,
Expand Down Expand Up @@ -129,6 +130,9 @@ const BedrockAPIConfig: ProviderAPIConfig = {
case 'imageGenerate': {
return endpoint;
}
case 'createBatch': {
return '/model-invocation-job';
}
default:
return '';
}
Expand Down
56 changes: 56 additions & 0 deletions src/providers/bedrock/createBatch.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import { BEDROCK } from '../../globals';
import { CreateBatchResponse, ErrorResponse, ProviderConfig } from '../types';
import { generateInvalidProviderResponseError } from '../utils';
import { BedrockErrorResponseTransform } from './chatComplete';
import { BedrockErrorResponse } from './embed';

export const BedrockCreateBatchConfig: ProviderConfig = {
model: {
param: 'modelId',
required: true,
},
input_file_id: {
param: 'inputDataConfig',
required: true,
transform: (params: CreateBatchResponse) => {
return {
s3Uri: params.input_file_id,
};
},
},
jobName: {
param: 'jobName',
transform: (params: CreateBatchResponse) => {
return 'portkey-batch-job-' + crypto.randomUUID();
},
},
outputDataConfig: {
param: 'outputDataConfig',
required: true,
},
roleArn: {
param: 'roleArn',
required: true,
},
};

export const BedrockCreateBatchResponseTransform: (
response: CreateBatchResponse | BedrockErrorResponse,
responseStatus: number
) => CreateBatchResponse | ErrorResponse = (response, responseStatus) => {
if (responseStatus !== 200) {
const errorResposne = BedrockErrorResponseTransform(
response as BedrockErrorResponse
);
if (errorResposne) return errorResposne;
}

if ('jobArn' in response) {
return {
id: response.jobArn as string,
object: 'batch',
};
}

return generateInvalidProviderResponseError(response, BEDROCK);
};
6 changes: 6 additions & 0 deletions src/providers/bedrock/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ import {
BedrockTitanCompleteStreamChunkTransform,
} from './complete';
import { BEDROCK_STABILITY_V1_MODELS } from './constants';
import {
BedrockCreateBatchConfig,
BedrockCreateBatchResponseTransform,
} from './createBatch';
import {
BedrockCohereEmbedConfig,
BedrockCohereEmbedResponseTransform,
Expand Down Expand Up @@ -182,8 +186,10 @@ const BedrockConfig: ProviderConfigs = {
if (!config.responseTransforms) {
config.responseTransforms = {
uploadFile: BedrockUploadFileResponseTransform,
createBatch: BedrockCreateBatchResponseTransform,
};
}
config.createBatch = BedrockCreateBatchConfig;
return config;
},
};
Expand Down
3 changes: 3 additions & 0 deletions src/providers/bedrock/types.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
interface BedrockCreateBatchResponse {
jobArn: string;
}

0 comments on commit a39994b

Please sign in to comment.