Skip to content

Commit

Permalink
Add default 'aws_conn_id' to SageMaker Operators #21808 (#23515)
Browse files Browse the repository at this point in the history
  • Loading branch information
hsrocks authored May 9, 2022
1 parent 8280167 commit 5d1e6ff
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 14 deletions.
49 changes: 38 additions & 11 deletions airflow/providers/amazon/aws/operators/sagemaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,14 @@
if TYPE_CHECKING:
from airflow.utils.context import Context

DEFAULT_CONN_ID = 'aws_default'
CHECK_INTERVAL_SECOND = 30


class SageMakerBaseOperator(BaseOperator):
"""This is the base operator for all SageMaker operators.
:param config: The configuration necessary to start a training job (templated)
:param aws_conn_id: The AWS connection ID to use.
"""

template_fields: Sequence[str] = ('config',)
Expand All @@ -48,9 +50,8 @@ class SageMakerBaseOperator(BaseOperator):
ui_color = '#ededed'
integer_fields: List[List[Any]] = []

def __init__(self, *, config: dict, aws_conn_id: str = 'aws_default', **kwargs):
def __init__(self, *, config: dict, **kwargs):
super().__init__(**kwargs)
self.aws_conn_id = aws_conn_id
self.config = config

def parse_integer(self, config, field):
Expand Down Expand Up @@ -117,6 +118,7 @@ class SageMakerProcessingOperator(SageMakerBaseOperator):
:param config: The configuration necessary to start a processing job (templated).
For details of the configuration parameter see :py:meth:`SageMaker.Client.create_processing_job`
:param aws_conn_id: The AWS connection ID to use.
:param wait_for_completion: Set to True to wait until the processing job finishes. While waiting,
the status of the processing job is checked every ``check_interval`` seconds.
:param print_log: if the operator should print the cloudwatch log during processing
Expand All @@ -134,9 +136,10 @@ def __init__(
self,
*,
config: dict,
aws_conn_id: str = DEFAULT_CONN_ID,
wait_for_completion: bool = True,
print_log: bool = True,
check_interval: int = 30,
check_interval: int = CHECK_INTERVAL_SECOND,
max_ingestion_time: Optional[int] = None,
action_if_job_exists: str = 'increment',
**kwargs,
Expand All @@ -148,6 +151,7 @@ def __init__(
Provided value: '{action_if_job_exists}'."
)
self.action_if_job_exists = action_if_job_exists
self.aws_conn_id = aws_conn_id
self.wait_for_completion = wait_for_completion
self.print_log = print_log
self.check_interval = check_interval
Expand Down Expand Up @@ -207,9 +211,16 @@ class SageMakerEndpointConfigOperator(SageMakerBaseOperator):

integer_fields = [['ProductionVariants', 'InitialInstanceCount']]

def __init__(self, *, config: dict, **kwargs):
def __init__(
self,
*,
config: dict,
aws_conn_id: str = DEFAULT_CONN_ID,
**kwargs,
):
super().__init__(config=config, **kwargs)
self.config = config
self.aws_conn_id = aws_conn_id

def execute(self, context: 'Context') -> dict:
self.preprocess_config()
Expand Down Expand Up @@ -265,21 +276,24 @@ class SageMakerEndpointOperator(SageMakerBaseOperator):
:param max_ingestion_time: If wait is set to True, this operation fails if the endpoint creation doesn't
finish within max_ingestion_time seconds. If you set this parameter to None it never times out.
:param operation: Whether to create an endpoint or update an endpoint. Must be either 'create' or 'update'.
:param aws_conn_id: The AWS connection ID to use.
:return Dict: Returns the ARN of the endpoint created in Amazon SageMaker.
"""

def __init__(
self,
*,
config: dict,
aws_conn_id: str = DEFAULT_CONN_ID,
wait_for_completion: bool = True,
check_interval: int = 30,
check_interval: int = CHECK_INTERVAL_SECOND,
max_ingestion_time: Optional[int] = None,
operation: str = 'create',
**kwargs,
):
super().__init__(config=config, **kwargs)
self.config = config
self.aws_conn_id = aws_conn_id
self.wait_for_completion = wait_for_completion
self.check_interval = check_interval
self.max_ingestion_time = max_ingestion_time
Expand Down Expand Up @@ -375,6 +389,7 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
For details of the configuration parameter of model_config, See:
:py:meth:`SageMaker.Client.create_model`
:param aws_conn_id: The AWS connection ID to use.
:param wait_for_completion: Set to True to wait until the transform job finishes.
:param check_interval: If wait is set to True, the time interval, in seconds,
that this operation waits to check the status of the transform job.
Expand All @@ -388,13 +403,15 @@ def __init__(
self,
*,
config: dict,
aws_conn_id: str = DEFAULT_CONN_ID,
wait_for_completion: bool = True,
check_interval: int = 30,
check_interval: int = CHECK_INTERVAL_SECOND,
max_ingestion_time: Optional[int] = None,
**kwargs,
):
super().__init__(config=config, **kwargs)
self.config = config
self.aws_conn_id = aws_conn_id
self.wait_for_completion = wait_for_completion
self.check_interval = check_interval
self.max_ingestion_time = max_ingestion_time
Expand Down Expand Up @@ -458,6 +475,7 @@ class SageMakerTuningOperator(SageMakerBaseOperator):
For details of the configuration parameter see
:py:meth:`SageMaker.Client.create_hyper_parameter_tuning_job`
:param aws_conn_id: The AWS connection ID to use.
:param wait_for_completion: Set to True to wait until the tuning job finishes.
:param check_interval: If wait is set to True, the time interval, in seconds,
that this operation waits to check the status of the tuning job.
Expand All @@ -479,13 +497,15 @@ def __init__(
self,
*,
config: dict,
aws_conn_id: str = DEFAULT_CONN_ID,
wait_for_completion: bool = True,
check_interval: int = 30,
check_interval: int = CHECK_INTERVAL_SECOND,
max_ingestion_time: Optional[int] = None,
**kwargs,
):
super().__init__(config=config, **kwargs)
self.config = config
self.aws_conn_id = aws_conn_id
self.wait_for_completion = wait_for_completion
self.check_interval = check_interval
self.max_ingestion_time = max_ingestion_time
Expand Down Expand Up @@ -528,12 +548,14 @@ class SageMakerModelOperator(SageMakerBaseOperator):
:param config: The configuration necessary to create a model.
For details of the configuration parameter see :py:meth:`SageMaker.Client.create_model`
:param aws_conn_id: The AWS connection ID to use.
:return Dict: Returns the ARN of the model created in Amazon SageMaker.
"""

def __init__(self, *, config, **kwargs):
def __init__(self, *, config, aws_conn_id: str = DEFAULT_CONN_ID, **kwargs):
super().__init__(config=config, **kwargs)
self.config = config
self.aws_conn_id = aws_conn_id

def expand_role(self) -> None:
if 'ExecutionRoleArn' in self.config:
Expand Down Expand Up @@ -562,6 +584,7 @@ class SageMakerTrainingOperator(SageMakerBaseOperator):
:param config: The configuration necessary to start a training job (templated).
For details of the configuration parameter see :py:meth:`SageMaker.Client.create_training_job`
:param aws_conn_id: The AWS connection ID to use.
:param wait_for_completion: Set to True to wait until the training job finishes. While waiting,
the status of the training job is checked every ``check_interval`` seconds.
:param print_log: if the operator should print the cloudwatch log during training
Expand All @@ -588,15 +611,17 @@ def __init__(
self,
*,
config: dict,
aws_conn_id: str = DEFAULT_CONN_ID,
wait_for_completion: bool = True,
print_log: bool = True,
check_interval: int = 30,
check_interval: int = CHECK_INTERVAL_SECOND,
max_ingestion_time: Optional[int] = None,
check_if_job_exists: bool = True,
action_if_job_exists: str = 'increment',
**kwargs,
):
super().__init__(config=config, **kwargs)
self.aws_conn_id = aws_conn_id
self.wait_for_completion = wait_for_completion
self.print_log = print_log
self.check_interval = check_interval
Expand Down Expand Up @@ -657,11 +682,13 @@ class SageMakerDeleteModelOperator(SageMakerBaseOperator):
:param config: The configuration necessary to delete the model.
For details of the configuration parameter see :py:meth:`SageMaker.Client.delete_model`
:param aws_conn_id: The AWS connection ID to use.
"""

def __init__(self, *, config, **kwargs):
def __init__(self, *, config, aws_conn_id: str = DEFAULT_CONN_ID, **kwargs):
super().__init__(config=config, **kwargs)
self.config = config
self.aws_conn_id = aws_conn_id

def execute(self, context: 'Context') -> Any:
sagemaker_hook = SageMakerHook(aws_conn_id=self.aws_conn_id)
Expand Down
4 changes: 1 addition & 3 deletions tests/providers/amazon/aws/operators/test_sagemaker_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,7 @@

class TestSageMakerBaseOperator(unittest.TestCase):
def setUp(self):
self.sagemaker = SageMakerBaseOperator(
task_id='test_sagemaker_operator', aws_conn_id='sagemaker_test_id', config=config
)
self.sagemaker = SageMakerBaseOperator(task_id='test_sagemaker_operator', config=config)

def test_parse_integer(self):
self.sagemaker.integer_fields = [['key1'], ['key2', 'key3'], ['key2', 'key4'], ['key5', 'key6']]
Expand Down

0 comments on commit 5d1e6ff

Please sign in to comment.