Skip to content

Commit

Permalink
[AIRFLOW-2524] Update SageMaker hook and operators (#4091)
Browse files Browse the repository at this point in the history
This re-works the SageMaker functionality in Airflow to be more complete, and more useful for the kinds of operations that SageMaker supports.

We removed some files and operators here, but these were only added after the last release so we don't need to worry about any sort of back-compat.
  • Loading branch information
yangaws authored and ashb committed Nov 1, 2018
1 parent 3be8ce7 commit ae39df5
Show file tree
Hide file tree
Showing 28 changed files with 2,279 additions and 1,304 deletions.
24 changes: 19 additions & 5 deletions airflow/contrib/hooks/aws_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
# specific language governing permissions and limitations
# under the License.


import boto3
import configparser
import logging
Expand Down Expand Up @@ -163,15 +162,17 @@ def _get_credentials(self, region_name):
aws_session_token=aws_session_token,
region_name=region_name), endpoint_url

def get_client_type(self, client_type, region_name=None):
def get_client_type(self, client_type, region_name=None, config=None):
session, endpoint_url = self._get_credentials(region_name)

return session.client(client_type, endpoint_url=endpoint_url)
return session.client(client_type, endpoint_url=endpoint_url,
config=config)

def get_resource_type(self, resource_type, region_name=None):
def get_resource_type(self, resource_type, region_name=None, config=None):
session, endpoint_url = self._get_credentials(region_name)

return session.resource(resource_type, endpoint_url=endpoint_url)
return session.resource(resource_type, endpoint_url=endpoint_url,
config=config)

def get_session(self, region_name=None):
"""Get the underlying boto3.session."""
Expand All @@ -188,3 +189,16 @@ def get_credentials(self, region_name=None):
# secret key separately can lead to a race condition.
# See https://stackoverflow.com/a/36291428/8283373
return session.get_credentials().get_frozen_credentials()

def expand_role(self, role):
"""
Expand an IAM role name to an IAM role ARN. If role is already an IAM ARN,
no change is made.
:param role: IAM role name or ARN
:return: IAM role ARN
"""
if '/' in role:
return role
else:
return self.get_client_type('iam').get_role(RoleName=role)['Role']['Arn']
Loading

0 comments on commit ae39df5

Please sign in to comment.