Skip to content

Commit

Permalink
[AIRFLOW-2524] Update SageMaker hook and operators (apache#4091)
Browse files Browse the repository at this point in the history
This re-works the SageMaker functionality in Airflow to be more complete, and more useful for the kinds of operations that SageMaker supports.

We removed some files and operators here, but these were only added after the last release so we don't need to worry about any sort of back-compat.
  • Loading branch information
yangaws authored and Alice Berard committed Jan 3, 2019
1 parent 29ffbaa commit c5a3ca3
Show file tree
Hide file tree
Showing 28 changed files with 2,170 additions and 1,302 deletions.
22 changes: 17 additions & 5 deletions airflow/contrib/hooks/aws_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
# specific language governing permissions and limitations
# under the License.


import boto3
import configparser
import logging
Expand Down Expand Up @@ -164,17 +163,17 @@ def _get_credentials(self, region_name):
aws_session_token=aws_session_token,
region_name=region_name), endpoint_url

def get_client_type(self, client_type, region_name=None):
def get_client_type(self, client_type, region_name=None, config=None):
session, endpoint_url = self._get_credentials(region_name)

return session.client(client_type, endpoint_url=endpoint_url,
verify=self.verify)
config=config, verify=self.verify)

def get_resource_type(self, resource_type, region_name=None):
def get_resource_type(self, resource_type, region_name=None, config=None):
session, endpoint_url = self._get_credentials(region_name)

return session.resource(resource_type, endpoint_url=endpoint_url,
verify=self.verify)
config=config, verify=self.verify)

def get_session(self, region_name=None):
"""Get the underlying boto3.session."""
Expand All @@ -191,3 +190,16 @@ def get_credentials(self, region_name=None):
# secret key separately can lead to a race condition.
# See https://stackoverflow.com/a/36291428/8283373
return session.get_credentials().get_frozen_credentials()

def expand_role(self, role):
"""
Expand an IAM role name to an IAM role ARN. If role is already an IAM ARN,
no change is made.
:param role: IAM role name or ARN
:return: IAM role ARN
"""
if '/' in role:
return role
else:
return self.get_client_type('iam').get_role(RoleName=role)['Role']['Arn']
Loading

0 comments on commit c5a3ca3

Please sign in to comment.