Skip to content

Commit

Permalink
feat: inference instance type conditioned on training instance type (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
evakravi authored Oct 26, 2023
1 parent 866a2d9 commit 03f901f
Show file tree
Hide file tree
Showing 10 changed files with 763 additions and 5 deletions.
15 changes: 15 additions & 0 deletions src/sagemaker/instance_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def retrieve_default(
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
training_instance_type: Optional[str] = None,
) -> str:
"""Retrieves the default instance type for the model matching the given arguments.
Expand All @@ -56,6 +57,11 @@ def retrieve_default(
object, used for SageMaker interactions. If not
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
training_instance_type (str): In the case of a model fine-tuned on SageMaker, the training
instance type used for the training job that produced the fine-tuned weights.
Optionally supply this to get a inference instance type conditioned
on the training instance, to ensure compatability of training artifact to inference
instance. (Default: None).
Returns:
str: The default instance type to use for the model.
Expand All @@ -78,6 +84,7 @@ def retrieve_default(
tolerate_vulnerable_model,
tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
training_instance_type=training_instance_type,
)


Expand All @@ -89,6 +96,7 @@ def retrieve(
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
training_instance_type: Optional[str] = None,
) -> List[str]:
"""Retrieves the supported training instance types for the model matching the given arguments.
Expand All @@ -110,6 +118,12 @@ def retrieve(
object, used for SageMaker interactions. If not
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
training_instance_type (str): In the case of a model fine-tuned on SageMaker, the training
instance type used for the training job that produced the fine-tuned weights.
Optionally supply this to get a inference instance type conditioned
on the training instance, to ensure compatability of training artifact to inference
instance. (Default: None).
Returns:
list: The supported instance types to use for the model.
Expand All @@ -132,4 +146,5 @@ def retrieve(
tolerate_vulnerable_model,
tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
training_instance_type=training_instance_type,
)
46 changes: 44 additions & 2 deletions src/sagemaker/jumpstart/artifacts/instance_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def _retrieve_default_instance_type(
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
training_instance_type: Optional[str] = None,
) -> str:
"""Retrieves the default instance type for the model.
Expand All @@ -60,6 +61,11 @@ def _retrieve_default_instance_type(
object, used for SageMaker interactions. If not
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
training_instance_type (str): In the case of a model fine-tuned on SageMaker, the training
instance type used for the training job that produced the fine-tuned weights.
Optionally supply this to get a inference instance type conditioned
on the training instance, to ensure compatability of training artifact to inference
instance. (Default: None).
Returns:
str: the default instance type to use for the model or None.
Expand All @@ -82,7 +88,21 @@ def _retrieve_default_instance_type(
)

if scope == JumpStartScriptScope.INFERENCE:
default_instance_type = model_specs.default_inference_instance_type
instance_specific_default_instance_type = (
(
model_specs.training_instance_type_variants.get_instance_specific_default_inference_instance_type( # pylint: disable=C0301 # noqa: E501
training_instance_type
)
)
if training_instance_type is not None
and getattr(model_specs, "training_instance_type_variants", None) is not None
else None
)
default_instance_type = (
instance_specific_default_instance_type
if instance_specific_default_instance_type is not None
else model_specs.default_inference_instance_type
)
elif scope == JumpStartScriptScope.TRAINING:
default_instance_type = model_specs.default_training_instance_type
else:
Expand All @@ -103,6 +123,7 @@ def _retrieve_instance_types(
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
training_instance_type: Optional[str] = None,
) -> List[str]:
"""Retrieves the supported instance types for the model.
Expand All @@ -126,6 +147,11 @@ def _retrieve_instance_types(
object, used for SageMaker interactions. If not
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
training_instance_type (str): In the case of a model fine-tuned on SageMaker, the training
instance type used for the training job that produced the fine-tuned weights.
Optionally supply this to get a inference instance type conditioned
on the training instance, to ensure compatability of training artifact to inference
instance. (Default: None).
Returns:
list: the supported instance types to use for the model or None.
Expand All @@ -148,8 +174,24 @@ def _retrieve_instance_types(
)

if scope == JumpStartScriptScope.INFERENCE:
instance_types = model_specs.supported_inference_instance_types
default_instance_types = model_specs.supported_inference_instance_types or []
instance_specific_instance_types = (
model_specs.training_instance_type_variants.get_instance_specific_supported_inference_instance_types( # pylint: disable=C0301 # noqa: E501
training_instance_type
)
if training_instance_type is not None
and getattr(model_specs, "training_instance_type_variants", None) is not None
else []
)
instance_types = (
instance_specific_instance_types
if len(instance_specific_instance_types) > 0
else default_instance_types
)

elif scope == JumpStartScriptScope.TRAINING:
if training_instance_type is not None:
raise ValueError("Cannot use `training_instance_type` argument " "with training scope.")
instance_types = model_specs.supported_training_instance_types
else:
raise NotImplementedError(
Expand Down
2 changes: 1 addition & 1 deletion src/sagemaker/jumpstart/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -988,7 +988,6 @@ def deploy(
use_compiled_model (bool): Flag to select whether to use compiled
(optimized) model. (Default: False).
"""

self.orig_predictor_cls = predictor_cls

sagemaker_session = sagemaker_session or self.sagemaker_session
Expand Down Expand Up @@ -1039,6 +1038,7 @@ def deploy(
dependencies=dependencies,
git_config=git_config,
use_compiled_model=use_compiled_model,
training_instance_type=self.instance_type,
)

predictor = super(JumpStartEstimator, self).deploy(
Expand Down
4 changes: 3 additions & 1 deletion src/sagemaker/jumpstart/factory/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,7 @@ def get_deploy_kwargs(
tolerate_vulnerable_model: Optional[bool] = None,
use_compiled_model: Optional[bool] = None,
model_name: Optional[str] = None,
training_instance_type: Optional[str] = None,
) -> JumpStartEstimatorDeployKwargs:
"""Returns kwargs required to call `deploy` on `sagemaker.estimator.Estimator` object."""

Expand Down Expand Up @@ -313,7 +314,7 @@ def get_deploy_kwargs(
model_id=model_id,
model_from_estimator=True,
model_version=model_version,
instance_type=model_deploy_kwargs.instance_type, # prevent excess logging
instance_type=model_deploy_kwargs.instance_type if training_instance_type is None else None,
region=region,
image_uri=image_uri,
source_dir=source_dir,
Expand All @@ -333,6 +334,7 @@ def get_deploy_kwargs(
git_config=git_config,
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
training_instance_type=training_instance_type,
)

estimator_deploy_kwargs: JumpStartEstimatorDeployKwargs = JumpStartEstimatorDeployKwargs(
Expand Down
3 changes: 3 additions & 0 deletions src/sagemaker/jumpstart/factory/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ def _add_instance_type_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartM
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
sagemaker_session=kwargs.sagemaker_session,
training_instance_type=kwargs.training_instance_type,
)

if orig_instance_type is None:
Expand Down Expand Up @@ -643,6 +644,7 @@ def get_init_kwargs(
dependencies: Optional[List[str]] = None,
git_config: Optional[Dict[str, str]] = None,
model_package_arn: Optional[str] = None,
training_instance_type: Optional[str] = None,
) -> JumpStartModelInitKwargs:
"""Returns kwargs required to instantiate `sagemaker.estimator.Model` object."""

Expand Down Expand Up @@ -671,6 +673,7 @@ def get_init_kwargs(
tolerate_deprecated_model=tolerate_deprecated_model,
tolerate_vulnerable_model=tolerate_vulnerable_model,
model_package_arn=model_package_arn,
training_instance_type=training_instance_type,
)

model_init_kwargs = _add_model_version_to_kwargs(kwargs=model_init_kwargs)
Expand Down
58 changes: 58 additions & 0 deletions src/sagemaker/jumpstart/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,6 +581,56 @@ def get_instance_specific_environment_variables(self, instance_type: str) -> Dic

return instance_family_environment_variables

def get_instance_specific_default_inference_instance_type(
self, instance_type: str
) -> Optional[str]:
"""Returns instance specific default inference instance type.
Returns None if a model, instance type tuple does not have instance
specific inference instance types.
"""

return self._get_instance_specific_property(
instance_type, "default_inference_instance_type"
)

def get_instance_specific_supported_inference_instance_types(
self, instance_type: str
) -> List[str]:
"""Returns instance specific supported inference instance types.
Returns empty list if a model, instance type tuple does not have instance
specific inference instance types.
"""

if self.variants is None:
return []

instance_specific_inference_instance_types: List[str] = (
self.variants.get(instance_type, {})
.get("properties", {})
.get("supported_inference_instance_types", [])
)

instance_type_family = get_instance_type_family(instance_type)

instance_family_inference_instance_types: List[str] = (
self.variants.get(instance_type_family, {})
.get("properties", {})
.get("supported_inference_instance_types", [])
if instance_type_family not in {"", None}
else []
)

return sorted(
list(
set(
instance_specific_inference_instance_types
+ instance_family_inference_instance_types
)
)
)

def get_image_uri(self, instance_type: str, region: str) -> Optional[str]:
"""Returns image uri from instance type and region.
Expand Down Expand Up @@ -971,6 +1021,7 @@ class JumpStartModelInitKwargs(JumpStartKwargs):
"dependencies",
"git_config",
"model_package_arn",
"training_instance_type",
]

SERIALIZATION_EXCLUSION_SET = {
Expand All @@ -981,6 +1032,7 @@ class JumpStartModelInitKwargs(JumpStartKwargs):
"tolerate_deprecated_model",
"region",
"model_package_arn",
"training_instance_type",
}

def __init__(
Expand Down Expand Up @@ -1009,6 +1061,7 @@ def __init__(
tolerate_vulnerable_model: Optional[bool] = None,
tolerate_deprecated_model: Optional[bool] = None,
model_package_arn: Optional[str] = None,
training_instance_type: Optional[str] = None,
) -> None:
"""Instantiates JumpStartModelInitKwargs object."""

Expand Down Expand Up @@ -1036,6 +1089,7 @@ def __init__(
self.tolerate_deprecated_model = tolerate_deprecated_model
self.tolerate_vulnerable_model = tolerate_vulnerable_model
self.model_package_arn = model_package_arn
self.training_instance_type = training_instance_type


class JumpStartModelDeployKwargs(JumpStartKwargs):
Expand Down Expand Up @@ -1065,6 +1119,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs):
"tolerate_vulnerable_model",
"tolerate_deprecated_model",
"sagemaker_session",
"training_instance_type",
]

SERIALIZATION_EXCLUSION_SET = {
Expand All @@ -1074,6 +1129,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs):
"tolerate_deprecated_model",
"tolerate_vulnerable_model",
"sagemaker_session",
"training_instance_type",
}

def __init__(
Expand Down Expand Up @@ -1101,6 +1157,7 @@ def __init__(
tolerate_deprecated_model: Optional[bool] = None,
tolerate_vulnerable_model: Optional[bool] = None,
sagemaker_session: Optional[Session] = None,
training_instance_type: Optional[str] = None,
) -> None:
"""Instantiates JumpStartModelDeployKwargs object."""

Expand All @@ -1127,6 +1184,7 @@ def __init__(
self.tolerate_vulnerable_model = tolerate_vulnerable_model
self.tolerate_deprecated_model = tolerate_deprecated_model
self.sagemaker_session = sagemaker_session
self.training_instance_type = training_instance_type


class JumpStartEstimatorInitKwargs(JumpStartKwargs):
Expand Down
Loading

0 comments on commit 03f901f

Please sign in to comment.