Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(stepfunctions-tasks): FastFile mode for SageMaker Training Job #26675

Merged
merged 12 commits into from
Aug 23, 2023
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,14 @@ export class SageMakerCreateTrainingJob extends sfn.TaskStateBase implements iam
throw new Error('Must define either an algorithm name or training image URI in the algorithm specification');
}

// Validate algorithmName only when one was supplied. The pattern accepts either a bare
// algorithm name or the same name preceded by a SageMaker ARN prefix.
// NOTE(review): an unresolved token passed as algorithmName would presumably fail this
// pattern check — confirm whether tokens need to be exempted here.
if (props.algorithmSpecification.algorithmName) {
  // Name part: 1-63 characters, alphanumeric or hyphen, starting with an alphanumeric.
  // The trailing negative lookbehind rejects a name that ends with a hyphen.
  const algorithmNamePattern = /^(arn:aws[a-z\-]*:sagemaker:[a-z0-9\-]*:[0-9]{12}:[a-z\-]*\/)?([a-zA-Z0-9]([a-zA-Z0-9-]){0,62})(?<!-)$/;
  const matchesPattern = algorithmNamePattern.test(props.algorithmSpecification.algorithmName);
  if (!matchesPattern) {
    // The pattern source is echoed in the message so the caller sees exactly what is expected.
    throw new Error(`Value '${props.algorithmSpecification.algorithmName}' at 'algorithmName' must satisfy regular expression pattern: ${algorithmNamePattern.source}`);
  }
}

// set the input mode to 'File' if not defined
this.algorithmSpecification = props.algorithmSpecification.trainingInputMode
? props.algorithmSpecification
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -408,3 +408,97 @@ test('Cannot create a SageMaker train task with both algorithm name and image na
}))
.toThrowError(/Must define either an algorithm name or training image URI in the algorithm specification/);
});

// Builds a training task from a JSON-path training image and no explicit input mode,
// then checks that the rendered state carries the image path and the 'File' input mode.
test('create a SageMaker train task with trainingImage', () => {
  // GIVEN
  const algorithmSpecification = {
    trainingImage: tasks.DockerImage.fromJsonExpression(sfn.JsonPath.stringAt('$.Training.imageName')),
  };
  const trainChannel = {
    channelName: 'train',
    dataSource: {
      s3DataSource: {
        s3DataType: tasks.S3DataType.S3_PREFIX,
        s3Location: tasks.S3Location.fromJsonExpression('$.S3Bucket'),
      },
    },
  };
  const outputDataConfig = {
    s3OutputLocation: tasks.S3Location.fromBucket(s3.Bucket.fromBucketName(stack, 'Bucket', 'mybucket'), 'myoutputpath/'),
  };

  // WHEN
  const task = new SageMakerCreateTrainingJob(stack, 'SageMakerTrainingTask', {
    trainingJobName: 'myTrainJob',
    algorithmSpecification,
    inputDataConfig: [trainChannel],
    outputDataConfig,
  });

  // THEN
  const renderedState = stack.resolve(task.toStateJson());
  expect(renderedState).toMatchObject({
    Parameters: {
      AlgorithmSpecification: {
        'TrainingImage.$': '$.Training.imageName',
        'TrainingInputMode': 'File',
      },
    },
  });
});

// Builds a training task whose algorithmName is an ARN-prefixed name and checks that the
// value is accepted and rendered unchanged.
// Note: the title says "image URI", but the value under test is a SageMaker algorithm ARN.
test('create a SageMaker train task with image URI algorithmName', () => {
  // GIVEN
  const algorithmSpecification = {
    algorithmName: 'arn:aws:sagemaker:us-east-1:123456789012:algorithm/scikit-decision-trees',
    trainingInputMode: tasks.InputMode.FILE,
  };
  const trainChannel = {
    channelName: 'train',
    dataSource: {
      s3DataSource: {
        s3DataType: tasks.S3DataType.S3_PREFIX,
        s3Location: tasks.S3Location.fromJsonExpression('$.S3Bucket'),
      },
    },
  };
  const outputDataConfig = {
    s3OutputLocation: tasks.S3Location.fromBucket(s3.Bucket.fromBucketName(stack, 'Bucket', 'mybucket'), 'myoutputpath/'),
  };

  // WHEN
  const task = new SageMakerCreateTrainingJob(stack, 'SageMakerTrainingTask', {
    trainingJobName: 'myTrainJob',
    algorithmSpecification,
    inputDataConfig: [trainChannel],
    outputDataConfig,
  });

  // THEN
  const renderedState = stack.resolve(task.toStateJson());
  expect(renderedState).toMatchObject({
    Parameters: {
      AlgorithmSpecification: {
        AlgorithmName: 'arn:aws:sagemaker:us-east-1:123456789012:algorithm/scikit-decision-trees',
      },
    },
  });
});

// Checks that constructing a training task with an algorithmName that does not match the
// required pattern throws an error naming the offending value.
test('Cannot create a SageMaker train task with incorrect algorithmName', () => {
  // WHEN
  // Construction is wrapped in a thunk so that the throw happens inside expect().
  const createTask = () => {
    return new SageMakerCreateTrainingJob(stack, 'SageMakerTrainingTask', {
      trainingJobName: 'myTrainJob',
      algorithmSpecification: {
        algorithmName: 'Blazing_Text', // underscores are not allowed
        trainingInputMode: tasks.InputMode.FILE,
      },
      inputDataConfig: [
        {
          channelName: 'train',
          dataSource: {
            s3DataSource: {
              s3DataType: tasks.S3DataType.S3_PREFIX,
              s3Location: tasks.S3Location.fromJsonExpression('$.S3Bucket'),
            },
          },
        },
      ],
      outputDataConfig: {
        s3OutputLocation: tasks.S3Location.fromBucket(s3.Bucket.fromBucketName(stack, 'Bucket', 'mybucket'), 'myoutputpath/'),
      },
    });
  };

  // THEN
  expect(createTask)
    .toThrowError(/'Blazing_Text' at 'algorithmName' must satisfy regular expression pattern/);
});