Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AIRFLOW-2524]Add SageMaker Batch Inference #3767

Merged
merged 7 commits into from
Sep 14, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 70 additions & 11 deletions airflow/contrib/hooks/sagemaker_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ class SageMakerHook(AwsHook):
sagemaker_conn_id is required for using
the config stored in db for training/tuning
"""
non_terminal_states = {'InProgress', 'Stopping', 'Stopped'}
failed_states = {'Failed'}

def __init__(self,
sagemaker_conn_id=None,
Expand Down Expand Up @@ -96,9 +98,9 @@ def check_status(self, non_terminal_states,
describe_function, *args):
"""
:param non_terminal_states: the set of non_terminal states
:type non_terminal_states: dict
:type non_terminal_states: set
:param failed_state: the set of failed states
:type failed_state: dict
:type failed_state: set
:param key: the key of the response dict
that points to the state
:type key: str
Expand Down Expand Up @@ -177,7 +179,7 @@ def create_training_job(self, training_job_config, wait_for_completion=True):
:param training_job_config: the config for training
:type training_job_config: dict
:param wait_for_completion: if the program should keep running until job finishes
:param wait_for_completion: bool
:type wait_for_completion: bool
:return: A dict that contains ARN of the training job.
"""
if self.use_db_config:
Expand All @@ -194,8 +196,8 @@ def create_training_job(self, training_job_config, wait_for_completion=True):
response = self.conn.create_training_job(
**training_job_config)
if wait_for_completion:
self.check_status(['InProgress', 'Stopping', 'Stopped'],
['Failed'],
self.check_status(SageMakerHook.non_terminal_states,
SageMakerHook.failed_states,
'TrainingJobStatus',
self.describe_training_job,
training_job_config['TrainingJobName'])
Expand All @@ -213,8 +215,8 @@ def create_tuning_job(self, tuning_job_config, wait_for_completion=True):
if self.use_db_config:
if not self.sagemaker_conn_id:
raise AirflowException(
"sagemaker connection id must be present to \
read sagemaker tunning job configuration.")
"SageMaker connection id must be present to \
read SageMaker tunning job configuration.")

sagemaker_conn = self.get_connection(self.sagemaker_conn_id)

Expand All @@ -226,13 +228,59 @@ def create_tuning_job(self, tuning_job_config, wait_for_completion=True):
response = self.conn.create_hyper_parameter_tuning_job(
**tuning_job_config)
if wait_for_completion:
self.check_status(['InProgress', 'Stopping', 'Stopped'],
['Failed'],
self.check_status(SageMakerHook.non_terminal_states,
SageMakerHook.failed_states,
'HyperParameterTuningJobStatus',
self.describe_tuning_job,
tuning_job_config['HyperParameterTuningJobName'])
return response

def create_transform_job(self, transform_job_config, wait_for_completion=True):
"""
Create a transform job
:param transform_job_config: the config for transform job
:type transform_job_config: dict
:param wait_for_completion:
if the program should keep running until job finishes
:type wait_for_completion: bool
:return: A dict that contains ARN of the transform job.
"""
if self.use_db_config:
if not self.sagemaker_conn_id:
raise AirflowException(
"SageMaker connection id must be present to \
read SageMaker transform job configuration.")

sagemaker_conn = self.get_connection(self.sagemaker_conn_id)

config = sagemaker_conn.extra_dejson.copy()
transform_job_config.update(config)

self.check_for_url(transform_job_config
['TransformInput']['DataSource']
['S3DataSource']['S3Uri'])

response = self.conn.create_transform_job(
**transform_job_config)
if wait_for_completion:
self.check_status(SageMakerHook.non_terminal_states,
SageMakerHook.failed_states,
'TransformJobStatus',
self.describe_transform_job,
transform_job_config['TransformJobName'])
return response

def create_model(self, model_config):
"""
Create a model job
:param model_config: the config for model
:type model_config: dict
:return: A dict that contains ARN of the model.
"""

return self.conn.create_model(
**model_config)

def describe_training_job(self, training_job_name):
"""
:param training_job_name: the name of the training job
Expand All @@ -245,11 +293,22 @@ def describe_training_job(self, training_job_name):

def describe_tuning_job(self, tuning_job_name):
"""
:param tuning_job_name: the name of the training job
:type tuning_job_name: str
:param tuning_job_name: the name of the tuning job
:type tuning_job_name: string
Return the tuning job info associated with the current job_name
:return: A dict contains all the tuning job info
"""
return self.conn\
.describe_hyper_parameter_tuning_job(
HyperParameterTuningJobName=tuning_job_name)

def describe_transform_job(self, transform_job_name):
"""
:param transform_job_name: the name of the transform job
:type transform_job_name: string
Return the transform job info associated with the current job_name
:return: A dict contains all the transform job info
"""
return self.conn\
.describe_transform_job(
TransformJobName=transform_job_name)
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class SageMakerCreateTrainingJobOperator(BaseOperator):
until training job finishes
:type wait_for_completion: bool
:param check_interval: if wait is set to be true, this is the time interval
which the operator will check the status of the training job
in seconds which the operator will check the status of the training job
:type check_interval: int
:param max_ingestion_time: if wait is set to be true, the operator will fail
if the training job hasn't finish within the max_ingestion_time
Expand Down
132 changes: 132 additions & 0 deletions airflow/contrib/operators/sagemaker_create_transform_job_operator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from airflow.contrib.hooks.sagemaker_hook import SageMakerHook
from airflow.models import BaseOperator
from airflow.utils.decorators import apply_defaults
from airflow.exceptions import AirflowException


class SageMakerCreateTransformJobOperator(BaseOperator):
"""
Initiate a SageMaker transform

This operator returns The ARN of the model created in Amazon SageMaker

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we order the docstring in the same order as the arguments?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@troychen728 has fixed it.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

:param sagemaker_conn_id: The SageMaker connection ID to use.
:type sagemaker_conn_id: string
:param transform_job_config:
The configuration necessary to start a transform job (templated)
:type transform_job_config: dict
:param model_config:
The configuration necessary to create a model, the default is none
which means that user should provide a created model in transform_job_config
If given, will be used to create a model before creating transform job
:type model_config: dict
:param use_db_config: Whether or not to use db config
associated with sagemaker_conn_id.
If set to true, will automatically update the transform config
with what's in db, so the db config doesn't need to
included everything, but what's there does replace the ones
in the transform_job_config, so be careful
:type use_db_config: bool
:param region_name: The AWS region_name
:type region_name: string
:param wait_for_completion: if the program should keep running until job finishes
:type wait_for_completion: bool
:param check_interval: if wait is set to be true, this is the time interval
in seconds which the operator will check the status of the transform job
:type check_interval: int
:param max_ingestion_time: if wait is set to be true, the operator will fail
if the transform job hasn't finish within the max_ingestion_time
(Caution: be careful to set this parameters because transform can take very long)
:type max_ingestion_time: int
:param aws_conn_id: The AWS connection ID to use.
:type aws_conn_id: string

**Example**:
The following operator would start a transform job when executed

sagemaker_transform =
SageMakerCreateTransformJobOperator(
task_id='sagemaker_transform',
transform_job_config=config_transform,
model_config=config_model,
region_name='us-west-2'
sagemaker_conn_id='sagemaker_customers_conn',
use_db_config=True,
aws_conn_id='aws_customers_conn'
)
"""

template_fields = ['transform_job_config']
template_ext = ()
ui_color = '#ededed'

@apply_defaults
def __init__(self,
sagemaker_conn_id=None,
transform_job_config=None,
model_config=None,
use_db_config=False,
region_name=None,
wait_for_completion=True,
check_interval=2,
max_ingestion_time=None,
*args, **kwargs):
super(SageMakerCreateTransformJobOperator, self).__init__(*args, **kwargs)

self.sagemaker_conn_id = sagemaker_conn_id
self.transform_job_config = transform_job_config
self.model_config = model_config
self.use_db_config = use_db_config
self.region_name = region_name
self.wait_for_completion = wait_for_completion
self.check_interval = check_interval
self.max_ingestion_time = max_ingestion_time

def execute(self, context):
sagemaker = SageMakerHook(
sagemaker_conn_id=self.sagemaker_conn_id,
use_db_config=self.use_db_config,
region_name=self.region_name,
check_interval=self.check_interval,
max_ingestion_time=self.max_ingestion_time
)

if self.model_config:
self.log.info(
"Creating SageMaker Model %s for transform job"
% self.model_config['ModelName']
)
sagemaker.create_model(self.model_config)

self.log.info(
"Creating SageMaker transform Job %s."
% self.transform_job_config['TransformJobName']
)
response = sagemaker.create_transform_job(
self.transform_job_config,
wait_for_completion=self.wait_for_completion)
if not response['ResponseMetadata']['HTTPStatusCode'] \
== 200:
raise AirflowException(
'Sagemaker transform Job creation failed: %s' % response)
else:
return response
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class SageMakerCreateTuningJobOperator(BaseOperator):
until tuning job finishes
:type wait_for_completion: bool
:param check_interval: if wait is set to be true, this is the time interval
which the operator will check the status of the tuning job
in seconds which the operator will check the status of the tuning job
:type check_interval: int
:param max_ingestion_time: if wait is set to be true, the operator will fail
if the tuning job hasn't finish within the max_ingestion_time
Expand Down
4 changes: 2 additions & 2 deletions airflow/contrib/sensors/sagemaker_training_sensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,10 @@ def __init__(self,
self.region_name = region_name

def non_terminal_states(self):
return ['InProgress', 'Stopping', 'Stopped']
return SageMakerHook.non_terminal_states

def failed_states(self):
return ['Failed']
return SageMakerHook.failed_states

def get_sagemaker_response(self):
sagemaker = SageMakerHook(
Expand Down
69 changes: 69 additions & 0 deletions airflow/contrib/sensors/sagemaker_transform_sensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from airflow.contrib.hooks.sagemaker_hook import SageMakerHook
from airflow.contrib.sensors.sagemaker_base_sensor import SageMakerBaseSensor
from airflow.utils.decorators import apply_defaults


class SageMakerTransformSensor(SageMakerBaseSensor):
"""
Asks for the state of the transform state until it reaches a terminal state.
The sensor will error if the job errors, throwing a AirflowException
containing the failure reason.

:param job_name: job_name of the transform job instance to check the state of
:type job_name: string
:param region_name: The AWS region_name
:type region_name: string
"""

template_fields = ['job_name']
template_ext = ()

@apply_defaults
def __init__(self,
job_name,
region_name=None,
*args,
**kwargs):
super(SageMakerTransformSensor, self).__init__(*args, **kwargs)
self.job_name = job_name
self.region_name = region_name

def non_terminal_states(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we make this one static?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not so sure why should I or what difference does it make if I make this static. Can you shed a little bit light on this? Thanks. Also, I am sorry but I might not be able to give feedback to the comments very quickly, because I am really busy with school stuff these few days.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since the function does not use anything from the class itself, it is prettier to make it static since it is then easier to reuse in other classes/functions.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Made non_terminal_states and failed_states static under SageMakerHook.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Much beter, thanks

return SageMakerHook.non_terminal_states

def failed_states(self):
return SageMakerHook.failed_states

def get_sagemaker_response(self):
sagemaker = SageMakerHook(
aws_conn_id=self.aws_conn_id,
region_name=self.region_name
)

self.log.info('Poking Sagemaker Transform Job %s', self.job_name)
return sagemaker.describe_transform_job(self.job_name)

def get_failed_reason_from_response(self, response):
return response['FailureReason']

def state_from_response(self, response):
return response['TransformJobStatus']
4 changes: 2 additions & 2 deletions airflow/contrib/sensors/sagemaker_tuning_sensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,10 @@ def __init__(self,
self.region_name = region_name

def non_terminal_states(self):
return ['InProgress', 'Stopping', 'Stopped']
return SageMakerHook.non_terminal_states

def failed_states(self):
return ['Failed']
return SageMakerHook.failed_states

def get_sagemaker_response(self):
sagemaker = SageMakerHook(
Expand Down
Loading