Skip to content

Commit

Permalink
rename guardrailConfiguration to guardrail
Browse files Browse the repository at this point in the history
  • Loading branch information
mazyu36 committed Jun 21, 2024
1 parent 57de8ad commit 2444620
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ const prompt = new BedrockInvokeModel(stack, 'Prompt', {
},
},
),
guardrailConfiguration: {
guardrail: {
guardrailIdentifier: guardrail.attrGuardrailId,
guardrailVersion: guardrail.attrVersion,
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,9 @@ export interface BedrockInvokeModelOutputProps {
}

/**
* Properties for the guardrail configuration.
* Properties for the guardrail.
*/
export interface GuardrailConfiguration {
export interface Guardrail {
/**
* The unique identifier of the guardrail that you want to use.
*/
Expand Down Expand Up @@ -126,7 +126,7 @@ export interface BedrockInvokeModelProps extends sfn.TaskStateBaseProps {
*
* @default - No guardrail is applied to the invocation.
*/
readonly guardrailConfiguration?: GuardrailConfiguration;
readonly guardrail?: Guardrail;

/**
* Specifies whether to enable or disable the Bedrock trace.
Expand Down Expand Up @@ -173,7 +173,7 @@ export class BedrockInvokeModel extends sfn.TaskStateBase {
throw new Error('Output S3 object version is not supported.');
}

this.validateGuardrailConfiguration(props);
this.validateGuardrail(props);

this.taskPolicies = this.renderPolicyStatements();
}
Expand Down Expand Up @@ -220,15 +220,15 @@ export class BedrockInvokeModel extends sfn.TaskStateBase {
);
}

if (this.props.guardrailConfiguration) {
if (this.props.guardrail) {
policyStatements.push(
new iam.PolicyStatement({
actions: ['bedrock:ApplyGuardrail'],
resources: [
Stack.of(this).formatArn({
service: 'bedrock',
resource: 'guardrail',
resourceName: this.props.guardrailConfiguration.guardrailIdentifier,
resourceName: this.props.guardrail.guardrailIdentifier,
}),
],
}),
Expand Down Expand Up @@ -257,8 +257,8 @@ export class BedrockInvokeModel extends sfn.TaskStateBase {
Output: this.props.output?.s3Location ? {
S3Uri: `s3://${this.props.output.s3Location.bucketName}/${this.props.output.s3Location.objectKey}`,
} : undefined,
GuardrailIdentifier: this.props.guardrailConfiguration?.guardrailIdentifier,
GuardrailVersion: this.props.guardrailConfiguration?.guardrailVersion,
GuardrailIdentifier: this.props.guardrail?.guardrailIdentifier,
GuardrailVersion: this.props.guardrail?.guardrailVersion,
Trace: this.props.traceEnabled === undefined
? undefined
: this.props.traceEnabled
Expand All @@ -268,18 +268,18 @@ export class BedrockInvokeModel extends sfn.TaskStateBase {
};
}

private validateGuardrailConfiguration(props: BedrockInvokeModelProps) {
if (!props.guardrailConfiguration) return;
private validateGuardrail(props: BedrockInvokeModelProps) {
if (!props.guardrail) return;

const { guardrailIdentifier, guardrailVersion } = props.guardrailConfiguration;
const { guardrailIdentifier, guardrailVersion } = props.guardrail;

if (!Token.isUnresolved(guardrailIdentifier)) {
const guardrailConfigurationPattern = /^(([a-z0-9]+)|(arn:aws(-[^:]+)?:bedrock:[a-z0-9-]{1,20}:[0-9]{12}:guardrail\/[a-z0-9]+))$/;
if (!guardrailConfigurationPattern.test(guardrailIdentifier)) {
const guardrailPattern = /^(([a-z0-9]+)|(arn:aws(-[^:]+)?:bedrock:[a-z0-9-]{1,20}:[0-9]{12}:guardrail\/[a-z0-9]+))$/;
if (!guardrailPattern.test(guardrailIdentifier)) {
throw new Error(`You must set guardrailIdentifier to the id or the arn of Guardrail, got ${guardrailIdentifier}`);
}
if (props.contentType !== 'application/json') {
throw new Error(`You must set contentType to \'application/json\' when using guardrailConfiguration, got '${props.contentType}'.`);
throw new Error(`You must set contentType to \'application/json\' when using guardrail, got '${props.contentType}'.`);
}
if (guardrailIdentifier.length > 2048) {
throw new Error(`\`guardrailIdentifier\` length must be between 0 and 2048, got ${guardrailIdentifier.length}.`);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ describe('Invoke Model', () => {
}).toThrow(/Output S3 object version is not supported./);
});

test('guardrail configuration', () => {
test('guardrail', () => {
// GIVEN
const stack = new cdk.Stack();
const model = bedrock.ProvisionedModel.fromProvisionedModelArn(stack, 'Imported', 'arn:aws:bedrock:us-turbo-2:123456789012:provisioned-model/abc-123');
Expand All @@ -375,7 +375,7 @@ describe('Invoke Model', () => {
prompt: 'Hello world',
},
),
guardrailConfiguration: {
guardrail: {
guardrailIdentifier: 'arn:aws:bedrock:us-turbo-2:123456789012:guardrail/testid',
guardrailVersion: 'DRAFT',
},
Expand Down Expand Up @@ -409,7 +409,7 @@ describe('Invoke Model', () => {
});
});

test('guardrail configuration fails when invalid guardrailIdentifier is set', () => {
test('guardrail fails when invalid guardrailIdentifier is set', () => {
// GIVEN
const stack = new cdk.Stack();
const model = bedrock.ProvisionedModel.fromProvisionedModelArn(stack, 'Imported', 'arn:aws:bedrock:us-turbo-2:123456789012:provisioned-model/abc-123');
Expand All @@ -424,7 +424,7 @@ describe('Invoke Model', () => {
prompt: 'Hello world',
},
),
guardrailConfiguration: {
guardrail: {
guardrailIdentifier: 'invalid-id',
guardrailVersion: 'DRAFT',
},
Expand All @@ -433,7 +433,7 @@ describe('Invoke Model', () => {
}).toThrow('You must set guardrailIdentifier to the id or the arn of Guardrail, got invalid-id');
});

test('guardrail configuration fails when guardrailIdentifier length is invalid', () => {
test('guardrail fails when guardrailIdentifier length is invalid', () => {
// GIVEN
const stack = new cdk.Stack();
const model = bedrock.ProvisionedModel.fromProvisionedModelArn(stack, 'Imported', 'arn:aws:bedrock:us-turbo-2:123456789012:provisioned-model/abc-123');
Expand All @@ -449,7 +449,7 @@ describe('Invoke Model', () => {
prompt: 'Hello world',
},
),
guardrailConfiguration: {
guardrail: {
guardrailIdentifier,
guardrailVersion: 'DRAFT',
},
Expand All @@ -458,7 +458,7 @@ describe('Invoke Model', () => {
}).toThrow(`\`guardrailIdentifier\` length must be between 0 and 2048, got ${guardrailIdentifier.length}.`);
});

test('guardrail configuration fails when invalid guardrailVersion is set', () => {
test('guardrail fails when invalid guardrailVersion is set', () => {
// GIVEN
const stack = new cdk.Stack();
const model = bedrock.ProvisionedModel.fromProvisionedModelArn(stack, 'Imported', 'arn:aws:bedrock:us-turbo-2:123456789012:provisioned-model/abc-123');
Expand All @@ -473,7 +473,7 @@ describe('Invoke Model', () => {
prompt: 'Hello world',
},
),
guardrailConfiguration: {
guardrail: {
guardrailIdentifier: 'abcdef',
guardrailVersion: 'test',
},
Expand All @@ -482,7 +482,7 @@ describe('Invoke Model', () => {
}).toThrow('guardrailVersion must match the ^(([1-9][0-9]{0,7})|(DRAFT))$ pattern, got test');
});

test('guardrail configuration fails when contentType is not \'application/json\'', () => {
test('guardrail fails when contentType is not \'application/json\'', () => {
// GIVEN
const stack = new cdk.Stack();
const model = bedrock.ProvisionedModel.fromProvisionedModelArn(stack, 'Imported', 'arn:aws:bedrock:us-turbo-2:123456789012:provisioned-model/abc-123');
Expand All @@ -497,13 +497,13 @@ describe('Invoke Model', () => {
prompt: 'Hello world',
},
),
guardrailConfiguration: {
guardrail: {
guardrailIdentifier: 'abcdef',
guardrailVersion: 'DRAFT',
},
});
// THEN
}).toThrow('You must set contentType to \'application/json\' when using guardrailConfiguration, got \'text/plain\'.');
}).toThrow('You must set contentType to \'application/json\' when using guardrail, got \'text/plain\'.');
});

test('trace configuration', () => {
Expand Down

0 comments on commit 2444620

Please sign in to comment.