
[AIRFLOW-2524] Add SageMaker Batch Inference #3767

Merged 7 commits on Sep 14, 2018
71 changes: 68 additions & 3 deletions airflow/contrib/hooks/sagemaker_hook.py
@@ -86,6 +86,16 @@ def check_valid_tuning_input(self, tuning_config):
self.check_for_url(channel['DataSource']
['S3DataSource']['S3Uri'])

def check_valid_transform_input(self, transform_config):
Contributor: This function feels a bit silly to me, why not check_for_url directly?

Contributor: @troychen728 fixed it.

"""
Run checks before a transform job starts
:param transform_config: the config for the transform job
:type transform_config: dict
:return: None
"""
self.check_for_url(transform_config
['TransformInput']['DataSource']['S3Uri'])

def check_status(self, non_terminal_states,
failed_state, key,
describe_function, *args):
@@ -206,8 +216,8 @@ def create_tuning_job(self, tuning_job_config):
if self.use_db_config:
if not self.sagemaker_conn_id:
raise AirflowException(
"sagemaker connection id must be present to \
read sagemaker tunning job configuration.")
"SageMaker connection id must be present to \
read SageMaker tunning job configuration.")

sagemaker_conn = self.get_connection(self.sagemaker_conn_id)

@@ -219,6 +229,50 @@ def create_tuning_job(self, tuning_job_config):
return self.conn.create_hyper_parameter_tuning_job(
**tuning_job_config)

def create_transform_job(self, transform_job_config, wait_for_completion=True):
"""
Create a tuning job
Contributor: This is not really in line with the name of the function, right? create_transform_job vs Create a tuning job

Contributor: @troychen728 fixed it.

:param transform_job_config: the config for transform job
:type transform_job_config: dict
:param wait_for_completion:
if the program should keep running until job finishes
:param wait_for_completion: bool
Contributor: Should be :type

Contributor: @troychen728 has fixed it.

:return: A dict that contains ARN of the transform job.
"""
if self.use_db_config:
if not self.sagemaker_conn_id:
raise AirflowException(
"SageMaker connection id must be present to \
read SageMaker transform job configuration.")

sagemaker_conn = self.get_connection(self.sagemaker_conn_id)

config = sagemaker_conn.extra_dejson.copy()
transform_job_config.update(config)

self.check_valid_transform_input(transform_job_config)

response = self.conn.create_transform_job(
**transform_job_config)
if wait_for_completion:
self.check_status(['InProgress', 'Stopping', 'Stopped'],
Contributor:

    def check_status(self, non_terminal_states,
                     failed_state, key,
                     describe_function, *args):
        """
        :param non_terminal_states: the set of non_terminal states
        :type non_terminal_states: dict
        :param failed_state: the set of failed states
        :type failed_state: dict
        :param key: the key of the response dict
        that points to the state
        :type key: string
        :param describe_function: the function used to retrieve the status
        :type describe_function: python callable
        :param args: the arguments for the function
        :return: None
        """

According to the docs, non_terminal_states and failed_state should be dicts. Can we make these sets as well? {'InProgress', 'Stopping', 'Stopped'}

Contributor: Made them into sets instead of dicts.
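For reference, a minimal sketch of how the call might read after that change (illustration only; the surrounding arguments are as in this diff):

    # hypothetical fragment inside create_transform_job, passing sets of states
    self.check_status({'InProgress', 'Stopping', 'Stopped'},
                      {'Failed'},
                      'TransformJobStatus',
                      self.describe_transform_job,
                      transform_job_config['TransformJobName'])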

['Failed'],
'TransformJobStatus',
self.describe_transform_job,
transform_job_config['TransformJobName'])
return response

def create_model(self, model_config):
"""
Create a model
:param model_config: the config for model
:type model_config: dict
:return: A dict that contains ARN of the model.
"""

return self.conn.create_model(
**model_config)

def describe_training_job(self, training_job_name):
"""
:param training_job_name: the name of the training job
@@ -231,11 +285,22 @@ def describe_training_job(self, training_job_name):

def describe_tuning_job(self, tuning_job_name):
"""
:param tuning_job_name: the name of the training job
:param tuning_job_name: the name of the tuning job
:type tuning_job_name: string
Return the tuning job info associated with the current job_name
:return: A dict contains all the tuning job info
"""
return self.conn\
.describe_hyper_parameter_tuning_job(
HyperParameterTuningJobName=tuning_job_name)

def describe_transform_job(self, transform_job_name):
"""
:param transform_job_name: the name of the transform job
:type transform_job_name: string
Return the transform job info associated with the current job_name
:return: A dict contains all the transform job info
"""
return self.conn\
.describe_transform_job(
TransformJobName=transform_job_name)
133 changes: 133 additions & 0 deletions airflow/contrib/operators/sagemaker_create_transform_job_operator.py
@@ -0,0 +1,133 @@
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from airflow.contrib.hooks.sagemaker_hook import SageMakerHook
from airflow.models import BaseOperator
from airflow.utils.decorators import apply_defaults
from airflow.exceptions import AirflowException


class SageMakerCreateTransformJobOperator(BaseOperator):

Contributor: Trim the \n please :-)

Contributor: @troychen728 has fixed it.

"""
Initiate a SageMaker transform

This operator returns the ARN of the model created in Amazon SageMaker

Contributor: Can we order the docstring in the same order as the arguments?

Contributor: @troychen728 has fixed it.

Contributor: 👍

:param sagemaker_conn_id: The SageMaker connection ID to use.
:type sagemaker_conn_id: string
:param transform_job_config:
The configuration necessary to start a transform job (templated)
:type transform_job_config: dict
:param model_config:
The configuration necessary to create a model; the default is None,
which means the user should provide an already-created model in transform_job_config.
If given, it will be used to create a model before creating the transform job
:type model_config: dict
:param region_name: The AWS region_name
:type region_name: string
:param use_db_config: Whether or not to use the db config
associated with sagemaker_conn_id.
If set to true, the transform config will automatically be updated
with what's in the db, so the db config doesn't need to
include everything; but what it does contain replaces the corresponding keys
in the transform_job_config, so be careful
:type use_db_config: bool
:param wait_for_completion: whether the operator should keep running until the job finishes
:type wait_for_completion: bool
:param check_interval: if wait_for_completion is set to true, this is the time interval,
in seconds, at which the operator will check the status of the transform job
:type check_interval: int
:param max_ingestion_time: if wait_for_completion is set to true, the operator will fail
if the transform job hasn't finished within max_ingestion_time
(Caution: be careful setting this parameter, because a transform job can take very long)
:type max_ingestion_time: int
:param aws_conn_id: The AWS connection ID to use.
:type aws_conn_id: string

**Example**:
The following operator would start a transform job when executed

    sagemaker_transform = SageMakerCreateTransformJobOperator(
        task_id='sagemaker_transform',
        transform_job_config=config_transform,
        model_config=config_model,
        region_name='us-west-2',
        sagemaker_conn_id='sagemaker_customers_conn',
        use_db_config=True,
        aws_conn_id='aws_customers_conn'
    )
"""

template_fields = ['transform_job_config']
template_ext = ()
ui_color = '#ededed'

@apply_defaults
def __init__(self,
transform_job_config=None,
model_config=None,
region_name=None,
sagemaker_conn_id=None,
use_db_config=False,
wait_for_completion=False,
Contributor: Can you change this one to True by default? This is how the other operators behave as well.

Contributor: @troychen728 has fixed it.

check_interval=2,
Contributor: Can you add a time unit here? For example, _seconds

Contributor (author): Do you suggest changing the name, or adding documentation?

Contributor: I updated the docstring to say the unit for check_interval is seconds. Please let me know if you have a strong preference that this should be in the variable name.

max_ingestion_time=None,
*args, **kwargs):
super(SageMakerCreateTransformJobOperator, self).__init__(*args, **kwargs)

self.sagemaker_conn_id = sagemaker_conn_id
self.transform_job_config = transform_job_config
self.model_config = model_config
self.use_db_config = use_db_config
self.region_name = region_name
self.wait_for_completion = wait_for_completion
self.check_interval = check_interval
self.max_ingestion_time = max_ingestion_time

def execute(self, context):
sagemaker = SageMakerHook(
sagemaker_conn_id=self.sagemaker_conn_id,
use_db_config=self.use_db_config,
region_name=self.region_name,
check_interval=self.check_interval,
max_ingestion_time=self.max_ingestion_time
)

if self.model_config:
self.log.info(
"Creating SageMaker Model %s for transform job"
% self.model_config['ModelName']
)
sagemaker.create_model(self.model_config)

self.log.info(
"Creating SageMaker transform Job %s."
% self.transform_job_config['TransformJobName']
)
response = sagemaker.create_transform_job(
self.transform_job_config,
wait_for_completion=self.wait_for_completion)
if not response['ResponseMetadata']['HTTPStatusCode'] \
== 200:
raise AirflowException(
'Sagemaker transform Job creation failed: %s' % response)
else:
return response
67 changes: 67 additions & 0 deletions airflow/contrib/sensors/sagemaker_transform_sensor.py
@@ -0,0 +1,67 @@
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from airflow.contrib.hooks.sagemaker_hook import SageMakerHook
from airflow.contrib.sensors.sagemaker_base_sensor import SageMakerBaseSensor
from airflow.utils.decorators import apply_defaults


class SageMakerTransformSensor(SageMakerBaseSensor):
"""
Asks for the state of the transform job until it reaches a terminal state.
The sensor will error if the job errors, throwing an AirflowException
containing the failure reason.

:param job_name: job_name of the transform job instance to check the state of
:type job_name: string
"""
Contributor: Missing region_name here

Contributor: @troychen728 has fixed it.
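A minimal sketch of the docstring entry being asked for (mirroring the region_name entry in the operator docstring above; illustration only):

    :param region_name: The AWS region_name
    :type region_name: string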


template_fields = ['job_name']
template_ext = ()

@apply_defaults
def __init__(self,
job_name,
region_name=None,
*args,
**kwargs):
super(SageMakerTransformSensor, self).__init__(*args, **kwargs)
self.job_name = job_name
self.region_name = region_name

def non_terminal_states(self):
Contributor: Can we make this one static?

Contributor (author): I'm not so sure why I should, or what difference it makes if I make this static. Can you shed a little light on this? Thanks. Also, I'm sorry, but I might not be able to respond to comments very quickly, because I'm really busy with school these few days.

Contributor: Since the function does not use anything from the class itself, it is prettier to make it static, since it is then easier to reuse in other classes/functions.
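A tiny illustration of that point (generic sketch with hypothetical names, not code from this PR):

    class Example:
        # touches no instance state, so it can be a @staticmethod and be
        # called as Example.is_terminal(...) without creating an instance
        @staticmethod
        def is_terminal(state):
            return state in {'Completed', 'Failed'}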

Contributor: Made non_terminal_states and failed_states static under SageMakerHook.

Contributor: Much better, thanks.

return ['InProgress', 'Stopping', 'Stopped']
Contributor: Can we consolidate these settings somewhere? I see them repeated quite a lot. Also, please change it to a set.

Contributor (author): I can write them into the base sensor because, at least for now, all the terminal_state and failed_state values are the same. However, I think it is a little risky, because later, if others are writing SageMaker sensors, a NotImplementedError will not be thrown, and the API return values are subject to change. Please let me know what you think, and I'll change accordingly.

Contributor: Yes, but most likely if the API changes, you need to change it in two places, and if you forget one of the two, you have a problem. This is of course a design question. Could you at least change them to a set, {'InProgress', 'Stopping', 'Stopped'}, and make them static?

Contributor: I added two static variables under class SageMakerHook:

    non_terminal_states = {'InProgress', 'Stopping', 'Stopped'}
    failed_states = {'Failed'}

All code in the hook and sensors now uses these two static variables. If in the future the API changes for all of training/tuning/inference, we only need to change this one place. But if the API changes for only some of them, we would still need to go back to the previous design.

For now I would say the API should be stable here, so I used static variables.
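A minimal sketch of the arrangement described above (illustration only, not the exact final diff):

    from airflow.contrib.hooks.aws_hook import AwsHook
    from airflow.contrib.sensors.sagemaker_base_sensor import SageMakerBaseSensor

    class SageMakerHook(AwsHook):
        # shared state sets defined once on the hook class
        non_terminal_states = {'InProgress', 'Stopping', 'Stopped'}
        failed_states = {'Failed'}

    class SageMakerTransformSensor(SageMakerBaseSensor):
        # the sensor overrides simply return the shared sets
        def non_terminal_states(self):
            return SageMakerHook.non_terminal_states

        def failed_states(self):
            return SageMakerHook.failed_states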


def failed_states(self):
Contributor: Can we make this one static as well?

Contributor: Reused static variables from SageMakerHook.

return ['Failed']

def get_sagemaker_response(self):
sagemaker = SageMakerHook(
aws_conn_id=self.aws_conn_id,
region_name=self.region_name
)

self.log.info('Poking Sagemaker Transform Job %s', self.job_name)
return sagemaker.describe_transform_job(self.job_name)

def get_failed_reason_from_response(self, response):
return response['FailureReason']

def state_from_response(self, response):
return response['TransformJobStatus']
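A hedged usage sketch of the sensor added above (the task ID, job name, and connection IDs are placeholders; poke_interval is the standard Airflow sensor polling interval in seconds):

    from airflow.contrib.sensors.sagemaker_transform_sensor import SageMakerTransformSensor

    wait_for_transform = SageMakerTransformSensor(
        task_id='wait_for_transform',
        job_name='my-transform-job',   # the TransformJobName passed to create_transform_job
        region_name='us-west-2',
        aws_conn_id='aws_default',
        poke_interval=60)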