From 03f901ff4c08260dc38cd39e097a8bec6e58c68c Mon Sep 17 00:00:00 2001 From: evakravi <69981223+evakravi@users.noreply.github.com> Date: Thu, 26 Oct 2023 19:20:52 -0400 Subject: [PATCH] feat: inference instance type conditioned on training instance type (#4230) --- src/sagemaker/instance_types.py | 15 + .../jumpstart/artifacts/instance_types.py | 46 +- src/sagemaker/jumpstart/estimator.py | 2 +- src/sagemaker/jumpstart/factory/estimator.py | 4 +- src/sagemaker/jumpstart/factory/model.py | 3 + src/sagemaker/jumpstart/types.py | 58 +++ .../jumpstart/test_instance_types.py | 82 ++++ tests/unit/sagemaker/jumpstart/constants.py | 442 ++++++++++++++++++ .../jumpstart/estimator/test_estimator.py | 61 +++ tests/unit/sagemaker/jumpstart/test_types.py | 55 ++- 10 files changed, 763 insertions(+), 5 deletions(-) diff --git a/src/sagemaker/instance_types.py b/src/sagemaker/instance_types.py index 111cc51f29..0471f374ae 100644 --- a/src/sagemaker/instance_types.py +++ b/src/sagemaker/instance_types.py @@ -33,6 +33,7 @@ def retrieve_default( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + training_instance_type: Optional[str] = None, ) -> str: """Retrieves the default instance type for the model matching the given arguments. @@ -56,6 +57,11 @@ def retrieve_default( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + training_instance_type (str): In the case of a model fine-tuned on SageMaker, the training + instance type used for the training job that produced the fine-tuned weights. + Optionally supply this to get an inference instance type conditioned + on the training instance, to ensure compatibility of training artifact to inference + instance. (Default: None). Returns: str: The default instance type to use for the model. 
@@ -78,6 +84,7 @@ def retrieve_default( tolerate_vulnerable_model, tolerate_deprecated_model, sagemaker_session=sagemaker_session, + training_instance_type=training_instance_type, ) @@ -89,6 +96,7 @@ def retrieve( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + training_instance_type: Optional[str] = None, ) -> List[str]: """Retrieves the supported training instance types for the model matching the given arguments. @@ -110,6 +118,12 @@ def retrieve( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + training_instance_type (str): In the case of a model fine-tuned on SageMaker, the training + instance type used for the training job that produced the fine-tuned weights. + Optionally supply this to get an inference instance type conditioned + on the training instance, to ensure compatibility of training artifact to inference + instance. (Default: None). + Returns: list: The supported instance types to use for the model. @@ -132,4 +146,5 @@ def retrieve( tolerate_vulnerable_model, tolerate_deprecated_model, sagemaker_session=sagemaker_session, + training_instance_type=training_instance_type, ) diff --git a/src/sagemaker/jumpstart/artifacts/instance_types.py b/src/sagemaker/jumpstart/artifacts/instance_types.py index 428a33708d..38e02e3ebd 100644 --- a/src/sagemaker/jumpstart/artifacts/instance_types.py +++ b/src/sagemaker/jumpstart/artifacts/instance_types.py @@ -37,6 +37,7 @@ def _retrieve_default_instance_type( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + training_instance_type: Optional[str] = None, ) -> str: """Retrieves the default instance type for the model. 
@@ -60,6 +61,11 @@ def _retrieve_default_instance_type( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + training_instance_type (str): In the case of a model fine-tuned on SageMaker, the training + instance type used for the training job that produced the fine-tuned weights. + Optionally supply this to get an inference instance type conditioned + on the training instance, to ensure compatibility of training artifact to inference + instance. (Default: None). Returns: str: the default instance type to use for the model or None. @@ -82,7 +88,21 @@ def _retrieve_default_instance_type( ) if scope == JumpStartScriptScope.INFERENCE: - default_instance_type = model_specs.default_inference_instance_type + instance_specific_default_instance_type = ( + ( + model_specs.training_instance_type_variants.get_instance_specific_default_inference_instance_type( # pylint: disable=C0301 # noqa: E501 + training_instance_type + ) + ) + if training_instance_type is not None + and getattr(model_specs, "training_instance_type_variants", None) is not None + else None + ) + default_instance_type = ( + instance_specific_default_instance_type + if instance_specific_default_instance_type is not None + else model_specs.default_inference_instance_type + ) elif scope == JumpStartScriptScope.TRAINING: default_instance_type = model_specs.default_training_instance_type else: @@ -103,6 +123,7 @@ def _retrieve_instance_types( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + training_instance_type: Optional[str] = None, ) -> List[str]: """Retrieves the supported instance types for the model. @@ -126,6 +147,11 @@ def _retrieve_instance_types( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. 
(Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + training_instance_type (str): In the case of a model fine-tuned on SageMaker, the training + instance type used for the training job that produced the fine-tuned weights. + Optionally supply this to get an inference instance type conditioned + on the training instance, to ensure compatibility of training artifact to inference + instance. (Default: None). Returns: list: the supported instance types to use for the model or None. @@ -148,8 +174,24 @@ def _retrieve_instance_types( ) if scope == JumpStartScriptScope.INFERENCE: - instance_types = model_specs.supported_inference_instance_types + default_instance_types = model_specs.supported_inference_instance_types or [] + instance_specific_instance_types = ( + model_specs.training_instance_type_variants.get_instance_specific_supported_inference_instance_types( # pylint: disable=C0301 # noqa: E501 + training_instance_type + ) + if training_instance_type is not None + and getattr(model_specs, "training_instance_type_variants", None) is not None + else [] + ) + instance_types = ( + instance_specific_instance_types + if len(instance_specific_instance_types) > 0 + else default_instance_types + ) + elif scope == JumpStartScriptScope.TRAINING: + if training_instance_type is not None: + raise ValueError("Cannot use `training_instance_type` argument " "with training scope.") instance_types = model_specs.supported_training_instance_types else: raise NotImplementedError( diff --git a/src/sagemaker/jumpstart/estimator.py b/src/sagemaker/jumpstart/estimator.py index 9fde1a348a..4f7a041df0 100644 --- a/src/sagemaker/jumpstart/estimator.py +++ b/src/sagemaker/jumpstart/estimator.py @@ -988,7 +988,6 @@ def deploy( use_compiled_model (bool): Flag to select whether to use compiled (optimized) model. (Default: False). 
""" - self.orig_predictor_cls = predictor_cls sagemaker_session = sagemaker_session or self.sagemaker_session @@ -1039,6 +1038,7 @@ def deploy( dependencies=dependencies, git_config=git_config, use_compiled_model=use_compiled_model, + training_instance_type=self.instance_type, ) predictor = super(JumpStartEstimator, self).deploy( diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py index 19f6a02915..1b24b714e7 100644 --- a/src/sagemaker/jumpstart/factory/estimator.py +++ b/src/sagemaker/jumpstart/factory/estimator.py @@ -280,6 +280,7 @@ def get_deploy_kwargs( tolerate_vulnerable_model: Optional[bool] = None, use_compiled_model: Optional[bool] = None, model_name: Optional[str] = None, + training_instance_type: Optional[str] = None, ) -> JumpStartEstimatorDeployKwargs: """Returns kwargs required to call `deploy` on `sagemaker.estimator.Estimator` object.""" @@ -313,7 +314,7 @@ def get_deploy_kwargs( model_id=model_id, model_from_estimator=True, model_version=model_version, - instance_type=model_deploy_kwargs.instance_type, # prevent excess logging + instance_type=model_deploy_kwargs.instance_type if training_instance_type is None else None, region=region, image_uri=image_uri, source_dir=source_dir, @@ -333,6 +334,7 @@ def get_deploy_kwargs( git_config=git_config, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, + training_instance_type=training_instance_type, ) estimator_deploy_kwargs: JumpStartEstimatorDeployKwargs = JumpStartEstimatorDeployKwargs( diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index 1d841fb39b..19605774ed 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -181,6 +181,7 @@ def _add_instance_type_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartM tolerate_deprecated_model=kwargs.tolerate_deprecated_model, 
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, + training_instance_type=kwargs.training_instance_type, ) if orig_instance_type is None: @@ -643,6 +644,7 @@ def get_init_kwargs( dependencies: Optional[List[str]] = None, git_config: Optional[Dict[str, str]] = None, model_package_arn: Optional[str] = None, + training_instance_type: Optional[str] = None, ) -> JumpStartModelInitKwargs: """Returns kwargs required to instantiate `sagemaker.estimator.Model` object.""" @@ -671,6 +673,7 @@ def get_init_kwargs( tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, model_package_arn=model_package_arn, + training_instance_type=training_instance_type, ) model_init_kwargs = _add_model_version_to_kwargs(kwargs=model_init_kwargs) diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 84cc2e66f2..c4b51cc8b8 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -581,6 +581,56 @@ def get_instance_specific_environment_variables(self, instance_type: str) -> Dic return instance_family_environment_variables + def get_instance_specific_default_inference_instance_type( + self, instance_type: str + ) -> Optional[str]: + """Returns instance specific default inference instance type. + + Returns None if a model, instance type tuple does not have instance + specific inference instance types. + """ + + return self._get_instance_specific_property( + instance_type, "default_inference_instance_type" + ) + + def get_instance_specific_supported_inference_instance_types( + self, instance_type: str + ) -> List[str]: + """Returns instance specific supported inference instance types. + + Returns empty list if a model, instance type tuple does not have instance + specific inference instance types. 
+ """ + + if self.variants is None: + return [] + + instance_specific_inference_instance_types: List[str] = ( + self.variants.get(instance_type, {}) + .get("properties", {}) + .get("supported_inference_instance_types", []) + ) + + instance_type_family = get_instance_type_family(instance_type) + + instance_family_inference_instance_types: List[str] = ( + self.variants.get(instance_type_family, {}) + .get("properties", {}) + .get("supported_inference_instance_types", []) + if instance_type_family not in {"", None} + else [] + ) + + return sorted( + list( + set( + instance_specific_inference_instance_types + + instance_family_inference_instance_types + ) + ) + ) + def get_image_uri(self, instance_type: str, region: str) -> Optional[str]: """Returns image uri from instance type and region. @@ -971,6 +1021,7 @@ class JumpStartModelInitKwargs(JumpStartKwargs): "dependencies", "git_config", "model_package_arn", + "training_instance_type", ] SERIALIZATION_EXCLUSION_SET = { @@ -981,6 +1032,7 @@ class JumpStartModelInitKwargs(JumpStartKwargs): "tolerate_deprecated_model", "region", "model_package_arn", + "training_instance_type", } def __init__( @@ -1009,6 +1061,7 @@ def __init__( tolerate_vulnerable_model: Optional[bool] = None, tolerate_deprecated_model: Optional[bool] = None, model_package_arn: Optional[str] = None, + training_instance_type: Optional[str] = None, ) -> None: """Instantiates JumpStartModelInitKwargs object.""" @@ -1036,6 +1089,7 @@ def __init__( self.tolerate_deprecated_model = tolerate_deprecated_model self.tolerate_vulnerable_model = tolerate_vulnerable_model self.model_package_arn = model_package_arn + self.training_instance_type = training_instance_type class JumpStartModelDeployKwargs(JumpStartKwargs): @@ -1065,6 +1119,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs): "tolerate_vulnerable_model", "tolerate_deprecated_model", "sagemaker_session", + "training_instance_type", ] SERIALIZATION_EXCLUSION_SET = { @@ -1074,6 +1129,7 @@ class 
JumpStartModelDeployKwargs(JumpStartKwargs): "tolerate_deprecated_model", "tolerate_vulnerable_model", "sagemaker_session", + "training_instance_type", } def __init__( @@ -1101,6 +1157,7 @@ def __init__( tolerate_deprecated_model: Optional[bool] = None, tolerate_vulnerable_model: Optional[bool] = None, sagemaker_session: Optional[Session] = None, + training_instance_type: Optional[str] = None, ) -> None: """Instantiates JumpStartModelDeployKwargs object.""" @@ -1127,6 +1184,7 @@ def __init__( self.tolerate_vulnerable_model = tolerate_vulnerable_model self.tolerate_deprecated_model = tolerate_deprecated_model self.sagemaker_session = sagemaker_session + self.training_instance_type = training_instance_type class JumpStartEstimatorInitKwargs(JumpStartKwargs): diff --git a/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py b/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py index f13121aa94..bed2e50674 100644 --- a/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py +++ b/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py @@ -163,6 +163,88 @@ def test_jumpstart_instance_types(patched_get_model_specs): instance_types.retrieve(model_id=model_id, scope="training") +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") +def test_jumpstart_inference_instance_type_variants(patched_get_model_specs): + patched_get_model_specs.side_effect = get_special_model_spec + + mock_client = boto3.client("s3") + mock_session = Mock(s3_client=mock_client) + model_id, model_version = "inference-instance-types-variant-model", "*" + region = "us-west-2" + + assert ["ml.inf1.2xlarge", "ml.inf1.xlarge"] == instance_types.retrieve( + region=region, + model_id=model_id, + model_version=model_version, + scope="inference", + sagemaker_session=mock_session, + training_instance_type="ml.trn1.xlarge", + ) + + assert ["ml.inf1.2xlarge", "ml.inf1.xlarge"] == instance_types.retrieve( + region=region, + 
model_id=model_id, + model_version=model_version, + scope="inference", + sagemaker_session=mock_session, + training_instance_type="ml.trn1.12xlarge", + ) + + assert ["ml.p2.xlarge", "ml.p3.xlarge", "ml.p5.xlarge"] == instance_types.retrieve( + region=region, + model_id=model_id, + model_version=model_version, + scope="inference", + sagemaker_session=mock_session, + training_instance_type="ml.p2.12xlarge", + ) + + assert ["ml.p4de.24xlarge"] == instance_types.retrieve( + region=region, + model_id=model_id, + model_version=model_version, + scope="inference", + sagemaker_session=mock_session, + training_instance_type="ml.p29s.12xlarge", + ) + + assert "ml.inf1.xlarge" == instance_types.retrieve_default( + region=region, + model_id=model_id, + model_version=model_version, + scope="inference", + sagemaker_session=mock_session, + training_instance_type="ml.trn1.xlarge", + ) + + assert "ml.inf1.xlarge" == instance_types.retrieve_default( + region=region, + model_id=model_id, + model_version=model_version, + scope="inference", + sagemaker_session=mock_session, + training_instance_type="ml.trn1.12xlarge", + ) + + assert "ml.p5.xlarge" == instance_types.retrieve_default( + region=region, + model_id=model_id, + model_version=model_version, + scope="inference", + sagemaker_session=mock_session, + training_instance_type="ml.p2.12xlarge", + ) + + assert "ml.p4de.24xlarge" == instance_types.retrieve_default( + region=region, + model_id=model_id, + model_version=model_version, + scope="inference", + sagemaker_session=mock_session, + training_instance_type="ml.p29s.12xlarge", + ) + + @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_no_supported_instance_types(patched_get_model_specs): patched_get_model_specs.side_effect = get_special_model_spec diff --git a/tests/unit/sagemaker/jumpstart/constants.py b/tests/unit/sagemaker/jumpstart/constants.py index 1cd6c92493..6551497318 100644 --- a/tests/unit/sagemaker/jumpstart/constants.py 
+++ b/tests/unit/sagemaker/jumpstart/constants.py @@ -174,6 +174,448 @@ }, }, }, + "inference-instance-types-variant-model": { + "model_id": "huggingface-llm-falcon-180b-bf16", + "url": "https://huggingface.co/tiiuae/falcon-180B", + "version": "1.0.0", + "min_sdk_version": "2.175.0", + "training_supported": True, + "incremental_training_supported": False, + "hosting_ecr_specs": { + "framework": "huggingface-llm", + "framework_version": "0.9.3", + "py_version": "py39", + "huggingface_transformers_version": "4.29.2", + }, + "hosting_artifact_key": "huggingface-infer/infer-huggingface-llm-falcon-180b-bf16.tar.gz", + "hosting_script_key": "source-directory-tarballs/huggingface/inference/llm/v1.0.1/sourcedir.tar.gz", + "hosting_prepacked_artifact_key": "huggingface-infer/prepack/v1.0.1/infer-prepack" + "-huggingface-llm-falcon-180b-bf16.tar.gz", + "hosting_prepacked_artifact_version": "1.0.1", + "hosting_use_script_uri": False, + "inference_vulnerable": False, + "inference_dependencies": [], + "inference_vulnerabilities": [], + "training_vulnerable": False, + "training_dependencies": [], + "training_vulnerabilities": [], + "deprecated": False, + "inference_environment_variables": [ + { + "name": "SAGEMAKER_PROGRAM", + "type": "text", + "default": "inference.py", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_SUBMIT_DIRECTORY", + "type": "text", + "default": "/opt/ml/model/code", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "SAGEMAKER_CONTAINER_LOG_LEVEL", + "type": "text", + "default": "20", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", + "type": "text", + "default": "3600", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "ENDPOINT_SERVER_TIMEOUT", + "type": "int", + "default": 3600, + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "MODEL_CACHE_ROOT", + "type": 
"text", + "default": "/opt/ml/model", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_ENV", + "type": "text", + "default": "1", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "HF_MODEL_ID", + "type": "text", + "default": "/opt/ml/model", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SM_NUM_GPUS", + "type": "text", + "default": "8", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "MAX_INPUT_LENGTH", + "type": "text", + "default": "1024", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "MAX_TOTAL_TOKENS", + "type": "text", + "default": "2048", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_MODEL_SERVER_WORKERS", + "type": "int", + "default": 1, + "scope": "container", + "required_for_model_class": True, + }, + ], + "metrics": [], + "default_inference_instance_type": "ml.p4de.24xlarge", + "supported_inference_instance_types": ["ml.p4de.24xlarge"], + "default_training_instance_type": "ml.p4de.24xlarge", + "supported_training_instance_types": ["ml.p4de.24xlarge"], + "model_kwargs": {}, + "deploy_kwargs": { + "model_data_download_timeout": 3600, + "container_startup_health_check_timeout": 3600, + }, + "predictor_specs": { + "supported_content_types": ["application/json"], + "supported_accept_types": ["application/json"], + "default_content_type": "application/json", + "default_accept_type": "application/json", + }, + "inference_volume_size": 512, + "inference_enable_network_isolation": True, + "validation_supported": False, + "fine_tuning_supported": False, + "resource_name_base": "hf-llm-falcon-180b-bf16", + "training_instance_type_variants": { + "regional_aliases": { + "us-west-2": { + "gpu_image_uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" + "huggingface-pytorch-inference:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04", + "gpu_image_uri_2": 
"763104351884.dkr.ecr.us-west-2.amazonaws.com/stud-gpu", + "cpu_image_uri": "867930986793.dkr.us-west-2.amazonaws.com/cpu-blah", + } + }, + "variants": { + "ml.p2.12xlarge": { + "properties": { + "environment_variables": {"TENSOR_PARALLEL_DEGREE": "4"}, + "supported_inference_instance_types": ["ml.p5.xlarge"], + "default_inference_instance_type": "ml.p5.xlarge", + "metrics": [ + { + "Name": "huggingface-textgeneration:eval-loss", + "Regex": "'eval_loss': ([0-9]+\\.[0-9]+)", + }, + { + "Name": "huggingface-textgeneration:instance-typemetric-loss", + "Regex": "'eval_loss': ([0-9]+\\.[0-9]+)", + }, + { + "Name": "huggingface-textgeneration:train-loss", + "Regex": "'instance type specific': ([0-9]+\\.[0-9]+)", + }, + { + "Name": "huggingface-textgeneration:noneyourbusiness-loss", + "Regex": "'loss-noyb instance specific': ([0-9]+\\.[0-9]+)", + }, + ], + } + }, + "p2": { + "regional_properties": {"image_uri": "$gpu_image_uri"}, + "properties": { + "supported_inference_instance_types": ["ml.p2.xlarge", "ml.p3.xlarge"], + "default_inference_instance_type": "ml.p2.xlarge", + "metrics": [ + { + "Name": "huggingface-textgeneration:wtafigo", + "Regex": "'evasadfasdl_loss': ([0-9]+\\.[0-9]+)", + }, + { + "Name": "huggingface-textgeneration:eval-loss", + "Regex": "'eval_loss': ([0-9]+\\.[0-9]+)", + }, + { + "Name": "huggingface-textgeneration:train-loss", + "Regex": "'instance family specific': ([0-9]+\\.[0-9]+)", + }, + { + "Name": "huggingface-textgeneration:noneyourbusiness-loss", + "Regex": "'loss-noyb': ([0-9]+\\.[0-9]+)", + }, + ], + }, + }, + "p3": {"regional_properties": {"image_uri": "$gpu_image_uri"}}, + "ml.p3.200xlarge": {"regional_properties": {"image_uri": "$gpu_image_uri_2"}}, + "p4": { + "regional_properties": {"image_uri": "$gpu_image_uri"}, + "properties": { + "prepacked_artifact_key": "path/to/prepacked/inference/artifact/prefix/number2/" + }, + }, + "g4": { + "regional_properties": {"image_uri": "$gpu_image_uri"}, + "properties": { + "artifact_key": 
"path/to/prepacked/training/artifact/prefix/number2/" + }, + }, + "g4dn": {"regional_properties": {"image_uri": "$gpu_image_uri"}}, + "g9": { + "regional_properties": {"image_uri": "$gpu_image_uri"}, + "properties": { + "prepacked_artifact_key": "asfs/adsf/sda/f", + "hyperparameters": [ + { + "name": "num_bag_sets", + "type": "int", + "default": 5, + "min": 5, + "scope": "algorithm", + }, + { + "name": "num_stack_levels", + "type": "int", + "default": 6, + "min": 7, + "max": 3, + "scope": "algorithm", + }, + { + "name": "refit_full", + "type": "text", + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "set_best_to_refit_full", + "type": "text", + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "save_space", + "type": "text", + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "verbosity", + "type": "int", + "default": 2, + "min": 0, + "max": 4, + "scope": "algorithm", + }, + { + "name": "sagemaker_submit_directory", + "type": "text", + "default": "/opt/ml/input/data/code/sourcedir.tar.gz", + "scope": "container", + }, + { + "name": "sagemaker_program", + "type": "text", + "default": "transfer_learning.py", + "scope": "container", + }, + { + "name": "sagemaker_container_log_level", + "type": "text", + "default": "20", + "scope": "container", + }, + ], + }, + }, + "p9": { + "regional_properties": {"image_uri": "$gpu_image_uri"}, + "properties": {"artifact_key": "do/re/mi"}, + }, + "m2": { + "regional_properties": {"image_uri": "$cpu_image_uri"}, + "properties": {"environment_variables": {"TENSOR_PARALLEL_DEGREE": "400"}}, + }, + "c2": {"regional_properties": {"image_uri": "$cpu_image_uri"}}, + "local": {"regional_properties": {"image_uri": "$cpu_image_uri"}}, + "ml.g5.48xlarge": { + "properties": {"environment_variables": {"TENSOR_PARALLEL_DEGREE": "8"}} + }, + "ml.g5.12xlarge": { + "properties": {"environment_variables": 
{"TENSOR_PARALLEL_DEGREE": "4"}} + }, + "g5": { + "properties": { + "environment_variables": {"TENSOR_PARALLEL_DEGREE": "4", "JOHN": "DOE"} + } + }, + "ml.g9.12xlarge": { + "properties": { + "environment_variables": {"TENSOR_PARALLEL_DEGREE": "4"}, + "prepacked_artifact_key": "nlahdasf/asdf/asd/f", + "hyperparameters": [ + { + "name": "eval_metric", + "type": "text", + "default": "auto", + "scope": "algorithm", + }, + { + "name": "presets", + "type": "text", + "default": "medium_quality", + "options": [ + "best_quality", + "high_quality", + "good_quality", + "medium_quality", + "optimize_for_deployment", + "interpretable", + ], + "scope": "algorithm", + }, + { + "name": "auto_stack", + "type": "text", + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "num_bag_folds", + "type": "text", + "default": "0", + "options": ["0", "2", "3", "4", "5", "6", "7", "8", "9", "10"], + "scope": "algorithm", + }, + { + "name": "num_bag_sets", + "type": "int", + "default": 1, + "min": 1, + "scope": "algorithm", + }, + { + "name": "num_stack_levels", + "type": "int", + "default": 0, + "min": 0, + "max": 3, + "scope": "algorithm", + }, + ], + } + }, + "ml.p9.12xlarge": { + "properties": { + "environment_variables": {"TENSOR_PARALLEL_DEGREE": "4"}, + "artifact_key": "you/not/entertained", + } + }, + "g6": { + "properties": { + "environment_variables": {"BLAH": "4"}, + "artifact_key": "path/to/training/artifact.tar.gz", + "prepacked_artifact_key": "path/to/prepacked/inference/artifact/prefix/", + } + }, + "trn1": { + "properties": { + "supported_inference_instance_types": ["ml.inf1.xlarge", "ml.inf1.2xlarge"], + "default_inference_instance_type": "ml.inf1.xlarge", + } + }, + }, + }, + "training_ecr_specs": { + "framework": "pytorch", + "framework_version": "1.5.0", + "py_version": "py3", + }, + "training_artifact_key": "pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz", + "training_script_key": 
"source-directory-tarballs/pytorch/transfer_learning/ic/v1.0.0/sourcedir.tar.gz", + "training_prepacked_script_key": None, + "training_model_package_artifact_uris": None, + "deprecate_warn_message": None, + "deprecated_message": None, + "hosting_eula_key": None, + "hyperparameters": [ + { + "name": "epochs", + "type": "int", + "default": 3, + "min": 1, + "max": 1000, + "scope": "algorithm", + }, + { + "name": "adam-learning-rate", + "type": "float", + "default": 0.05, + "min": 1e-08, + "max": 1, + "scope": "algorithm", + }, + { + "name": "batch-size", + "type": "int", + "default": 4, + "min": 1, + "max": 1024, + "scope": "algorithm", + }, + { + "name": "sagemaker_submit_directory", + "type": "text", + "default": "/opt/ml/input/data/code/sourcedir.tar.gz", + "scope": "container", + }, + { + "name": "sagemaker_program", + "type": "text", + "default": "transfer_learning.py", + "scope": "container", + }, + { + "name": "sagemaker_container_log_level", + "type": "text", + "default": "20", + "scope": "container", + }, + ], + "training_vulnerable": False, + "deprecated": False, + "estimator_kwargs": { + "encrypt_inter_container_traffic": True, + }, + "training_volume_size": 456, + "inference_enable_network_isolation": True, + "training_enable_network_isolation": False, + }, "variant-model": { "model_id": "pytorch-ic-mobilenet-v2", "url": "https://pytorch.org/hub/pytorch_vision_mobilenet_v2/", diff --git a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py index bc84bc6442..6f4788fa04 100644 --- a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py +++ b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py @@ -1126,6 +1126,67 @@ def test_incremental_training_with_supported_model_doesnt_log_warning( sagemaker_session=sagemaker_session, ) + @mock.patch("sagemaker.utils.sagemaker_timestamp") + @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + 
@mock.patch("sagemaker.jumpstart.factory.model.Session") + @mock.patch("sagemaker.jumpstart.factory.estimator.Session") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.fit") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.deploy") + @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region) + @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_estimator_sets_different_inference_instance_depending_on_training_instance( + self, + mock_estimator_deploy: mock.Mock, + mock_estimator_fit: mock.Mock, + mock_estimator_init: mock.Mock, + mock_get_model_specs: mock.Mock, + mock_session_estimator: mock.Mock, + mock_session_model: mock.Mock, + mock_is_valid_model_id: mock.Mock, + mock_sagemaker_timestamp: mock.Mock, + ): + mock_is_valid_model_id.return_value = True + + mock_sagemaker_timestamp.return_value = "3456" + + mock_estimator_deploy.return_value = default_predictor + + model_id = "inference-instance-types-variant-model" + + mock_get_model_specs.side_effect = get_special_model_spec + + mock_session_estimator.return_value = sagemaker_session + mock_session_model.return_value = sagemaker_session + + estimator = JumpStartEstimator( + model_id=model_id, image_uri="blah", instance_type="ml.trn1.xlarge" + ) + estimator.deploy(image_uri="blah") + assert mock_estimator_deploy.call_args[1]["instance_type"] == "ml.inf1.xlarge" + mock_estimator_deploy.reset_mock() + + estimator = JumpStartEstimator( + model_id=model_id, image_uri="blah", instance_type="ml.p2.xlarge" + ) + estimator.deploy(image_uri="blah") + assert mock_estimator_deploy.call_args[1]["instance_type"] == "ml.p2.xlarge" + mock_estimator_deploy.reset_mock() + + estimator = JumpStartEstimator( + model_id=model_id, image_uri="blah", instance_type="ml.p2.12xlarge" + ) + 
estimator.deploy(image_uri="blah") + assert mock_estimator_deploy.call_args[1]["instance_type"] == "ml.p5.xlarge" + mock_estimator_deploy.reset_mock() + + estimator = JumpStartEstimator( + model_id=model_id, image_uri="blah", instance_type="ml.blah.xblah" + ) + estimator.deploy(image_uri="blah") + assert mock_estimator_deploy.call_args[1]["instance_type"] == "ml.p4de.24xlarge" + @mock.patch("sagemaker.utils.sagemaker_timestamp") @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") @mock.patch("sagemaker.jumpstart.factory.model.Session") diff --git a/tests/unit/sagemaker/jumpstart/test_types.py b/tests/unit/sagemaker/jumpstart/test_types.py index 4aff263e96..e269eab5a3 100644 --- a/tests/unit/sagemaker/jumpstart/test_types.py +++ b/tests/unit/sagemaker/jumpstart/test_types.py @@ -35,6 +35,8 @@ "ml.p2.12xlarge": { "properties": { "environment_variables": {"TENSOR_PARALLEL_DEGREE": "4"}, + "supported_inference_instance_types": ["ml.p5.xlarge"], + "default_inference_instance_type": "ml.p5.xlarge", "metrics": [ { "Name": "huggingface-textgeneration:eval-loss", @@ -58,6 +60,8 @@ "p2": { "regional_properties": {"image_uri": "$gpu_image_uri"}, "properties": { + "supported_inference_instance_types": ["ml.p2.xlarge", "ml.p3.xlarge"], + "default_inference_instance_type": "ml.p2.xlarge", "metrics": [ { "Name": "huggingface-textgeneration:wtafigo", @@ -75,7 +79,7 @@ "Name": "huggingface-textgeneration:noneyourbusiness-loss", "Regex": "'loss-noyb': ([0-9]+\\.[0-9]+)", }, - ] + ], }, }, "p3": {"regional_properties": {"image_uri": "$gpu_image_uri"}}, @@ -254,6 +258,12 @@ "prepacked_artifact_key": "path/to/prepacked/inference/artifact/prefix/", } }, + "trn1": { + "properties": { + "supported_inference_instance_types": ["ml.inf1.xlarge", "ml.inf1.2xlarge"], + "default_inference_instance_type": "ml.inf1.xlarge", + } + }, }, } ) @@ -665,6 +675,49 @@ def test_jumpstart_hyperparameter_instance_variants(): assert hyperparams == [] +def 
test_jumpstart_inference_instance_type_variants(): + assert INSTANCE_TYPE_VARIANT.get_instance_specific_supported_inference_instance_types( + "ml.p2.xlarge" + ) == ["ml.p2.xlarge", "ml.p3.xlarge"] + assert ( + INSTANCE_TYPE_VARIANT.get_instance_specific_default_inference_instance_type("ml.p2.2xlarge") + == "ml.p2.xlarge" + ) + + assert INSTANCE_TYPE_VARIANT.get_instance_specific_supported_inference_instance_types( + "ml.p2.12xlarge" + ) == ["ml.p2.xlarge", "ml.p3.xlarge", "ml.p5.xlarge"] + assert ( + INSTANCE_TYPE_VARIANT.get_instance_specific_default_inference_instance_type( + "ml.p2.12xlarge" + ) + == "ml.p5.xlarge" + ) + + assert ( + INSTANCE_TYPE_VARIANT.get_instance_specific_supported_inference_instance_types( + "ml.sdfsad.12xlarge" + ) + == [] + ) + assert ( + INSTANCE_TYPE_VARIANT.get_instance_specific_default_inference_instance_type( + "ml.adfas.12xlarge" + ) + is None + ) + + assert INSTANCE_TYPE_VARIANT.get_instance_specific_supported_inference_instance_types( + "ml.trn1.12xlarge" + ) == ["ml.inf1.2xlarge", "ml.inf1.xlarge"] + assert ( + INSTANCE_TYPE_VARIANT.get_instance_specific_default_inference_instance_type( + "ml.trn1.12xlarge" + ) + == "ml.inf1.xlarge" + ) + + def test_jumpstart_environment_variables_instance_variants(): assert INSTANCE_TYPE_VARIANT.get_instance_specific_environment_variables( instance_type="ml.g9.12xlarge"