diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index 8b1badb94b..619af2f7a9 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -41,18 +41,17 @@ from sagemaker.jumpstart.types import ( JumpStartSerializablePayload, DeploymentConfigMetadata, - JumpStartBenchmarkStat, - JumpStartMetadataConfig, ) from sagemaker.jumpstart.utils import ( validate_model_id_and_get_type, verify_model_region_and_return_specs, get_jumpstart_configs, get_metrics_from_deployment_configs, + add_instance_rate_stats_to_benchmark_metrics, ) from sagemaker.jumpstart.constants import JUMPSTART_LOGGER from sagemaker.jumpstart.enums import JumpStartModelType -from sagemaker.utils import stringify_object, format_tags, Tags, get_instance_rate_per_hour +from sagemaker.utils import stringify_object, format_tags, Tags from sagemaker.model import ( Model, ModelPackage, @@ -361,17 +360,13 @@ def _validate_model_id_and_type(): self.model_package_arn = model_init_kwargs.model_package_arn self.init_kwargs = model_init_kwargs.to_kwargs_dict(False) - metadata_configs = get_jumpstart_configs( + self._metadata_configs = get_jumpstart_configs( region=self.region, model_id=self.model_id, model_version=self.model_version, sagemaker_session=self.sagemaker_session, model_type=self.model_type, ) - self._deployment_configs = [ - self._convert_to_deployment_config_metadata(config_name, config) - for config_name, config in metadata_configs.items() - ] def log_subscription_warning(self) -> None: """Log message prompting the customer to subscribe to the proprietary model.""" @@ -449,25 +444,33 @@ def set_deployment_config(self, config_name: str, instance_type: str) -> None: @property def deployment_config(self) -> Optional[Dict[str, Any]]: - """The deployment config that will be applied to the model. + """The deployment config that will be applied to ``This`` model. Returns: - Optional[Dict[str, Any]]: Deployment config that will be applied to the model. + Optional[Dict[str, Any]]: Deployment config. """ - return self._retrieve_selected_deployment_config(self.config_name) + deployment_config = self._retrieve_selected_deployment_config( + self.config_name, self.instance_type + ) + return deployment_config.to_json() if deployment_config is not None else None @property def benchmark_metrics(self) -> pd.DataFrame: - """Benchmark Metrics for deployment configs + """Benchmark Metrics for deployment configs. Returns: - Metrics: Pandas DataFrame object. + Benchmark Metrics: Pandas DataFrame object. """ - return pd.DataFrame(self._get_benchmarks_data(self.config_name)) + benchmark_metrics_data = self._get_deployment_configs_benchmarks_data( + self.config_name, self.instance_type + ) + keys = list(benchmark_metrics_data.keys()) + df = pd.DataFrame(benchmark_metrics_data).sort_values(by=[keys[0], keys[1]]) + return df def display_benchmark_metrics(self) -> None: - """Display Benchmark Metrics for deployment configs.""" - print(self.benchmark_metrics.to_markdown()) + """Display deployment configs benchmark metrics.""" + print(self.benchmark_metrics.to_markdown(index=False)) def list_deployment_configs(self) -> List[Dict[str, Any]]: """List deployment configs for ``This`` model. @@ -475,7 +478,12 @@ def list_deployment_configs(self) -> List[Dict[str, Any]]: Returns: List[Dict[str, Any]]: A list of deployment configs. """ - return self._deployment_configs + return [ + deployment_config.to_json() + for deployment_config in self._get_deployment_configs( + self.config_name, self.instance_type + ) + ] def _create_sagemaker_model( self, @@ -866,92 +874,94 @@ def register_deploy_wrapper(*args, **kwargs): return model_package @lru_cache - def _get_benchmarks_data(self, config_name: str) -> Dict[str, List[str]]: + def _get_deployment_configs_benchmarks_data( + self, config_name: str, instance_type: str + ) -> Dict[str, Any]: """Deployment configs benchmark metrics. Args: - config_name (str): The name of the selected deployment config. + config_name (str): Name of selected deployment config. + instance_type (str): The selected Instance type. Returns: Dict[str, List[str]]: Deployment config benchmark data. """ return get_metrics_from_deployment_configs( - self._deployment_configs, - config_name, + self._get_deployment_configs(config_name, instance_type) ) @lru_cache - def _retrieve_selected_deployment_config(self, config_name: str) -> Optional[Dict[str, Any]]: - """Retrieve the deployment config to apply to the model. + def _retrieve_selected_deployment_config( + self, config_name: str, instance_type: str + ) -> Optional[DeploymentConfigMetadata]: + """Retrieve the deployment config to apply to `This` model. Args: config_name (str): The name of the deployment config to retrieve. + instance_type (str): The instance type of the deployment config to retrieve. Returns: Optional[Dict[str, Any]]: The retrieved deployment config. """ if config_name is None: return None - for deployment_config in self._deployment_configs: - if deployment_config.get("DeploymentConfigName") == config_name: + for deployment_config in self._get_deployment_configs(config_name, instance_type): + if deployment_config.deployment_config_name == config_name: return deployment_config return None - def _convert_to_deployment_config_metadata( - self, config_name: str, metadata_config: JumpStartMetadataConfig - ) -> Dict[str, Any]: - """Retrieve deployment config for config name. + @lru_cache + def _get_deployment_configs( + self, selected_config_name: str, selected_instance_type: str + ) -> List[DeploymentConfigMetadata]: + """Retrieve deployment configs metadata. Args: - config_name (str): Name of deployment config. - metadata_config (JumpStartMetadataConfig): Metadata config for deployment config. - Returns: - A deployment metadata config for config name (dict[str, Any]). + selected_config_name (str): The name of the selected deployment config. + selected_instance_type (str): The selected instance type. """ - default_inference_instance_type = metadata_config.resolved_config.get( - "default_inference_instance_type" - ) - - benchmark_metrics = ( - metadata_config.benchmark_metrics.get(default_inference_instance_type) - if metadata_config.benchmark_metrics is not None - else None - ) - - should_fetch_instance_rate_metric = True - if benchmark_metrics is not None: - for benchmark_metric in benchmark_metrics: - if benchmark_metric.name.lower() == "instance rate": - should_fetch_instance_rate_metric = False - break - - if should_fetch_instance_rate_metric: - instance_rate = get_instance_rate_per_hour( - instance_type=default_inference_instance_type, region=self.region + deployment_configs = [] + if self._metadata_configs is None: + return deployment_configs + + err = None + for config_name, metadata_config in self._metadata_configs.items(): + if err is None or "is not authorized to perform: pricing:GetProducts" not in err: + err, metadata_config.benchmark_metrics = ( + add_instance_rate_stats_to_benchmark_metrics( + self.region, metadata_config.benchmark_metrics + ) + ) + + resolved_config = metadata_config.resolved_config + if selected_config_name == config_name: + instance_type_to_use = selected_instance_type + else: + instance_type_to_use = resolved_config.get("default_inference_instance_type") + + init_kwargs = get_init_kwargs( + model_id=self.model_id, + instance_type=instance_type_to_use, + sagemaker_session=self.sagemaker_session, ) - if instance_rate is not None: - instance_rate_metric = JumpStartBenchmarkStat(instance_rate) - - if benchmark_metrics is None: - benchmark_metrics = [instance_rate_metric] - else: - benchmark_metrics.append(instance_rate_metric) - - init_kwargs = get_init_kwargs( - model_id=self.model_id, - instance_type=default_inference_instance_type, - sagemaker_session=self.sagemaker_session, - ) - deploy_kwargs = get_deploy_kwargs( - model_id=self.model_id, - instance_type=default_inference_instance_type, - sagemaker_session=self.sagemaker_session, - ) + deploy_kwargs = get_deploy_kwargs( + model_id=self.model_id, + instance_type=instance_type_to_use, + sagemaker_session=self.sagemaker_session, + ) + deployment_config_metadata = DeploymentConfigMetadata( + config_name, + metadata_config.benchmark_metrics, + resolved_config, + init_kwargs, + deploy_kwargs, + ) + deployment_configs.append(deployment_config_metadata) - deployment_config_metadata = DeploymentConfigMetadata( - config_name, benchmark_metrics, init_kwargs, deploy_kwargs - ) + if err is not None and "is not authorized to perform: pricing:GetProducts" in err: + error_message = "Instance rate metrics will be omitted. Reason: %s" + JUMPSTART_LOGGER.warning(error_message, err) - return deployment_config_metadata.to_json() + return deployment_configs def __str__(self) -> str: """Overriding str(*) method to make more human-readable.""" diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index cd74a03e5a..e0a0f9bea7 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -2235,29 +2235,37 @@ def to_json(self) -> Dict[str, Any]: if hasattr(self, att): cur_val = getattr(self, att) att = self._convert_to_pascal_case(att) - if issubclass(type(cur_val), JumpStartDataHolderType): - json_obj[att] = cur_val.to_json() - elif isinstance(cur_val, list): - json_obj[att] = [] - for obj in cur_val: - if issubclass(type(obj), JumpStartDataHolderType): - json_obj[att].append(obj.to_json()) - else: - json_obj[att].append(obj) - elif isinstance(cur_val, dict): - json_obj[att] = {} - for key, val in cur_val.items(): - if issubclass(type(val), JumpStartDataHolderType): - json_obj[att][self._convert_to_pascal_case(key)] = val.to_json() - else: - json_obj[att][key] = val - else: - json_obj[att] = cur_val + json_obj[att] = self._val_to_json(cur_val) return json_obj + def _val_to_json(self, val: Any) -> Any: + """Converts the given value to JSON. + + Args: + val (Any): The value to convert. + Returns: + Any: The converted json value. + """ + if issubclass(type(val), JumpStartDataHolderType): + return val.to_json() + if isinstance(val, list): + list_obj = [] + for obj in val: + list_obj.append(self._val_to_json(obj)) + return list_obj + if isinstance(val, dict): + dict_obj = {} + for k, v in val.items(): + if isinstance(v, JumpStartDataHolderType): + dict_obj[self._convert_to_pascal_case(k)] = self._val_to_json(v) + else: + dict_obj[k] = self._val_to_json(v) + return dict_obj + return val + class DeploymentArgs(BaseDeploymentConfigDataHolder): - """Dataclass representing a Deployment Config.""" + """Dataclass representing a Deployment Args.""" __slots__ = [ "image_uri", @@ -2270,9 +2278,12 @@ class DeploymentArgs(BaseDeploymentConfigDataHolder): ] def __init__( - self, init_kwargs: JumpStartModelInitKwargs, deploy_kwargs: JumpStartModelDeployKwargs + self, + init_kwargs: Optional[JumpStartModelInitKwargs] = None, + deploy_kwargs: Optional[JumpStartModelDeployKwargs] = None, + resolved_config: Optional[Dict[str, Any]] = None, ): - """Instantiates DeploymentConfig object.""" + """Instantiates DeploymentArgs object.""" if init_kwargs is not None: self.image_uri = init_kwargs.image_uri self.model_data = init_kwargs.model_data @@ -2287,6 +2298,11 @@ def __init__( self.container_startup_health_check_timeout = ( deploy_kwargs.container_startup_health_check_timeout ) + if resolved_config is not None: + self.default_instance_type = resolved_config.get("default_inference_instance_type") + self.supported_instance_types = resolved_config.get( + "supported_inference_instance_types" + ) class DeploymentConfigMetadata(BaseDeploymentConfigDataHolder): @@ -2301,13 +2317,15 @@ class DeploymentConfigMetadata(BaseDeploymentConfigDataHolder): def __init__( self, - config_name: str, - benchmark_metrics: List[JumpStartBenchmarkStat], - init_kwargs: JumpStartModelInitKwargs, - deploy_kwargs: JumpStartModelDeployKwargs, + config_name: Optional[str] = None, + benchmark_metrics: Optional[Dict[str, List[JumpStartBenchmarkStat]]] = None, + resolved_config: Optional[Dict[str, Any]] = None, + init_kwargs: Optional[JumpStartModelInitKwargs] = None, + deploy_kwargs: Optional[JumpStartModelDeployKwargs] = None, ): """Instantiates DeploymentConfigMetadata object.""" self.deployment_config_name = config_name - self.deployment_args = DeploymentArgs(init_kwargs, deploy_kwargs) - self.acceleration_configs = None + self.deployment_args = DeploymentArgs(init_kwargs, deploy_kwargs, resolved_config) self.benchmark_metrics = benchmark_metrics + if resolved_config is not None: + self.acceleration_configs = resolved_config.get("acceleration_configs") diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 357bdb6eb7..a8c4bd7c21 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -17,6 +17,7 @@ from typing import Any, Dict, List, Set, Optional, Tuple, Union from urllib.parse import urlparse import boto3 +from botocore.exceptions import ClientError from packaging.version import Version import sagemaker from sagemaker.config.config_schema import ( @@ -41,10 +42,11 @@ JumpStartModelHeader, JumpStartModelSpecs, JumpStartVersionedModelId, + DeploymentConfigMetadata, ) from sagemaker.session import Session from sagemaker.config import load_sagemaker_config -from sagemaker.utils import resolve_value_from_config, TagsDict +from sagemaker.utils import resolve_value_from_config, TagsDict, get_instance_rate_per_hour from sagemaker.workflow import is_pipeline_variable @@ -1030,60 +1032,110 @@ def get_jumpstart_configs( ) -def get_metrics_from_deployment_configs( - deployment_configs: List[Dict[str, Any]], config_name: str -) -> Dict[str, List[str]]: - """Extracts metrics from deployment configs. +def add_instance_rate_stats_to_benchmark_metrics( + region: str, + benchmark_metrics: Optional[Dict[str, List[JumpStartBenchmarkStat]]], +) -> Optional[Tuple[str, Dict[str, List[JumpStartBenchmarkStat]]]]: + """Adds instance types metric stats to the given benchmark_metrics dict. Args: - deployment_configs (list[dict[str, Any]]): List of deployment configs. - config_name (str): The name of the deployment config use by the model. + region (str): AWS region. + benchmark_metrics (Dict[str, List[JumpStartBenchmarkStat]]): + Returns: + Tuple[str, Dict[str, List[JumpStartBenchmarkStat]]]: + Contains Error message and metrics dict. """ - data = {"Config Name": [], "Instance Type": [], "Selected": [], "Accelerated": []} + if benchmark_metrics is None: + return None + + final_benchmark_metrics = {} - for index, deployment_config in enumerate(deployment_configs): - if deployment_config.get("DeploymentArgs") is None: - continue + err_message = None + for instance_type, benchmark_metric_stats in benchmark_metrics.items(): + instance_type = instance_type if instance_type.startswith("ml.") else f"ml.{instance_type}" - benchmark_metrics = deployment_config.get("BenchmarkMetrics") - if benchmark_metrics is not None: - data["Config Name"].append(deployment_config.get("DeploymentConfigName")) - data["Instance Type"].append( - deployment_config.get("DeploymentArgs").get("InstanceType") - ) - data["Selected"].append( - "Yes" - if ( - config_name is not None - and config_name == deployment_config.get("DeploymentConfigName") + if not has_instance_rate_stat(benchmark_metric_stats): + try: + instance_type_rate = get_instance_rate_per_hour( + instance_type=instance_type, region=region ) - else "No" - ) - accelerated_configs = deployment_config.get("AccelerationConfigs") - if accelerated_configs is None: - data["Accelerated"].append("No") - else: - data["Accelerated"].append( - "Yes" - if ( - len(accelerated_configs) > 0 - and accelerated_configs[0].get("Enabled", False) - ) - else "No" + benchmark_metric_stats.append(JumpStartBenchmarkStat(instance_type_rate)) + final_benchmark_metrics[instance_type] = benchmark_metric_stats + + except ClientError as e: + final_benchmark_metrics[instance_type] = benchmark_metric_stats + err_message = e.response["Error"]["Message"] + except Exception: # pylint: disable=W0703 + final_benchmark_metrics[instance_type] = benchmark_metric_stats + err_message = ( + f"Unable to get instance rate per hour for instance type: {instance_type}." ) - if index == 0: - for benchmark_metric in benchmark_metrics: - column_name = f"{benchmark_metric.get('name')} ({benchmark_metric.get('unit')})" - data[column_name] = [] + return err_message, final_benchmark_metrics + + +def has_instance_rate_stat(benchmark_metric_stats: Optional[List[JumpStartBenchmarkStat]]) -> bool: + """Determines whether a benchmark metric stats contains instance rate metric stat. + + Args: + benchmark_metric_stats (Optional[List[JumpStartBenchmarkStat]]): + List of benchmark metric stats. + Returns: + bool: Whether the benchmark metric stats contains instance rate metric stat. + """ + if benchmark_metric_stats is None: + return False + + for benchmark_metric_stat in benchmark_metric_stats: + if benchmark_metric_stat.name.lower() == "instance rate": + return True + + return False + - for benchmark_metric in benchmark_metrics: - column_name = f"{benchmark_metric.get('name')} ({benchmark_metric.get('unit')})" - if column_name in data.keys(): - data[column_name].append(benchmark_metric.get("value")) +def get_metrics_from_deployment_configs( + deployment_configs: List[DeploymentConfigMetadata], +) -> Dict[str, List[str]]: + """Extracts benchmark metrics from deployment configs metadata. - if "Yes" not in data["Accelerated"]: - del data["Accelerated"] + Args: + deployment_configs (List[DeploymentConfigMetadata]): List of deployment configs metadata. + """ + data = {"Config Name": [], "Instance Type": []} + + for outer_index, deployment_config in enumerate(deployment_configs): + if deployment_config.deployment_args is None: + continue + + benchmark_metrics = deployment_config.benchmark_metrics + if benchmark_metrics is None: + continue + + for inner_index, current_instance_type in enumerate(benchmark_metrics): + current_instance_type_metrics = benchmark_metrics[current_instance_type] + + data["Config Name"].append(deployment_config.deployment_config_name) + instance_type_to_display = ( + f"{current_instance_type} (Default)" + if current_instance_type == deployment_config.deployment_args.default_instance_type + else current_instance_type + ) + data["Instance Type"].append(instance_type_to_display) + + if outer_index == 0 and inner_index == 0: + temp_data = {} + for metric in current_instance_type_metrics: + column_name = f"{metric.name.replace('_', ' ').title()} ({metric.unit})" + if metric.name.lower() == "instance rate": + data[column_name] = [] + else: + temp_data[column_name] = [] + data = {**data, **temp_data} + + for metric in current_instance_type_metrics: + column_name = f"{metric.name.replace('_', ' ').title()} ({metric.unit})" + if column_name in data: + data[column_name].append(metric.value) return data diff --git a/src/sagemaker/serve/builder/jumpstart_builder.py b/src/sagemaker/serve/builder/jumpstart_builder.py index d3c2581885..ec987dd9fe 100644 --- a/src/sagemaker/serve/builder/jumpstart_builder.py +++ b/src/sagemaker/serve/builder/jumpstart_builder.py @@ -431,18 +431,21 @@ def tune_for_tgi_jumpstart(self, max_tuning_duration: int = 1800): sharded_supported=sharded_supported, max_tuning_duration=max_tuning_duration ) - def set_deployment_config(self, config_name: Optional[str]) -> None: + def set_deployment_config(self, config_name: str, instance_type: str) -> None: """Sets the deployment config to apply to the model. Args: - config_name (Optional[str]): - The name of the deployment config. Set to None to unset - any existing config that is applied to the model. + config_name (str): + The name of the deployment config to apply to the model. + Call list_deployment_configs to see the list of config names. + instance_type (str): + The instance_type that the model will use after setting + the config. """ if not hasattr(self, "pysdk_model") or self.pysdk_model is None: raise Exception("Cannot set deployment config to an uninitialized model.") - self.pysdk_model.set_deployment_config(config_name) + self.pysdk_model.set_deployment_config(config_name, instance_type) def get_deployment_config(self) -> Optional[Dict[str, Any]]: """Gets the deployment config to apply to the model. diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index 35f60b37e1..6c9e1b4b16 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -1664,17 +1664,21 @@ def deep_override_dict( def get_instance_rate_per_hour( instance_type: str, region: str, -) -> Union[Dict[str, str], None]: +) -> Optional[Dict[str, str]]: """Gets instance rate per hour for the given instance type. Args: instance_type (str): The instance type. region (str): The region. Returns: - Union[Dict[str, str], None]: Instance rate per hour. - Example: {'name': 'Instance Rate', 'unit': 'USD/Hrs', 'value': '1.1250000000'}}. - """ + Optional[Dict[str, str]]: Instance rate per hour. + Example: {'name': 'Instance Rate', 'unit': 'USD/Hrs', 'value': '1.125'}. + Raises: + Exception: An exception is raised if + the IAM role is not authorized to perform pricing:GetProducts. + or unexpected event happened. + """ region_name = "us-east-1" if region.startswith("eu") or region.startswith("af"): region_name = "eu-central-1" @@ -1682,35 +1686,34 @@ def get_instance_rate_per_hour( region_name = "ap-south-1" pricing_client: boto3.client = boto3.client("pricing", region_name=region_name) - try: - res = pricing_client.get_products( - ServiceCode="AmazonSageMaker", - Filters=[ - {"Type": "TERM_MATCH", "Field": "instanceName", "Value": instance_type}, - {"Type": "TERM_MATCH", "Field": "locationType", "Value": "AWS Region"}, - {"Type": "TERM_MATCH", "Field": "regionCode", "Value": region}, - ], - ) + res = pricing_client.get_products( + ServiceCode="AmazonSageMaker", + Filters=[ + {"Type": "TERM_MATCH", "Field": "instanceName", "Value": instance_type}, + {"Type": "TERM_MATCH", "Field": "locationType", "Value": "AWS Region"}, + {"Type": "TERM_MATCH", "Field": "regionCode", "Value": region}, + ], + ) - price_list = res.get("PriceList", []) - if len(price_list) > 0: - price_data = price_list[0] - if isinstance(price_data, str): - price_data = json.loads(price_data) + price_list = res.get("PriceList", []) + if len(price_list) > 0: + price_data = price_list[0] + if isinstance(price_data, str): + price_data = json.loads(price_data) - return extract_instance_rate_per_hour(price_data) - except Exception as e: # pylint: disable=W0703 - logging.exception("Error getting instance rate: %s", e) - return None + instance_rate_per_hour = extract_instance_rate_per_hour(price_data) + if instance_rate_per_hour is not None: + return instance_rate_per_hour + raise Exception(f"Unable to get instance rate per hour for instance type: {instance_type}.") -def extract_instance_rate_per_hour(price_data: Dict[str, Any]) -> Union[Dict[str, str], None]: +def extract_instance_rate_per_hour(price_data: Dict[str, Any]) -> Optional[Dict[str, str]]: """Extract instance rate per hour for the given Price JSON data. Args: price_data (Dict[str, Any]): The Price JSON data. Returns: - Union[Dict[str, str], None]: Instance rate per hour. + Optional[Dict[str, str], None]: Instance rate per hour. """ if price_data is not None: @@ -1718,9 +1721,12 @@ def extract_instance_rate_per_hour(price_data: Dict[str, Any]) -> Union[Dict[str for dimension in price_dimensions: for price in dimension.get("priceDimensions", {}).values(): for currency in price.get("pricePerUnit", {}).keys(): + value = price.get("pricePerUnit", {}).get(currency) + if value is not None: + value = str(round(float(value), 3)) return { "unit": f"{currency}/{price.get('unit', 'Hrs')}", - "value": price.get("pricePerUnit", {}).get(currency), + "value": value, "name": "Instance Rate", } return None diff --git a/tests/unit/sagemaker/jumpstart/model/test_model.py b/tests/unit/sagemaker/jumpstart/model/test_model.py index 5bbc31a5b1..cd11d950d5 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_model.py +++ b/tests/unit/sagemaker/jumpstart/model/test_model.py @@ -15,6 +15,7 @@ from typing import Optional, Set from unittest import mock import unittest + import pandas as pd from mock import MagicMock, Mock import pytest @@ -52,6 +53,7 @@ get_mock_init_kwargs, get_base_deployment_configs, get_base_spec_with_prototype_configs_with_missing_benchmarks, + append_instance_stat_metrics, ) import boto3 @@ -66,7 +68,6 @@ class ModelTest(unittest.TestCase): - mock_session_empty_config = MagicMock(sagemaker_config={}) @mock.patch( @@ -1714,19 +1715,17 @@ def test_model_set_deployment_config_incompatible_instance_type_or_name( @mock.patch("sagemaker.jumpstart.model.get_init_kwargs") @mock.patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs") - @mock.patch("sagemaker.jumpstart.model.get_instance_rate_per_hour") + @mock.patch("sagemaker.jumpstart.model.add_instance_rate_stats_to_benchmark_metrics") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") - @mock.patch("sagemaker.jumpstart.model.Model.deploy") @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) def test_model_list_deployment_configs( self, - mock_model_deploy: mock.Mock, mock_get_model_specs: mock.Mock, mock_session: mock.Mock, mock_get_manifest: mock.Mock, - mock_get_instance_rate_per_hour: mock.Mock, + mock_add_instance_rate_stats_to_benchmark_metrics: mock.Mock, mock_verify_model_region_and_return_specs: mock.Mock, mock_get_init_kwargs: mock.Mock, ): @@ -1736,16 +1735,14 @@ def test_model_list_deployment_configs( mock_verify_model_region_and_return_specs.side_effect = ( lambda *args, **kwargs: get_base_spec_with_prototype_configs() ) - mock_get_instance_rate_per_hour.side_effect = lambda *args, **kwargs: { - "name": "Instance Rate", - "unit": "USD/Hrs", - "value": "0.0083000000", - } + mock_add_instance_rate_stats_to_benchmark_metrics.side_effect = lambda region, metrics: ( + None, + append_instance_stat_metrics(metrics), + ) mock_get_model_specs.side_effect = get_prototype_spec_with_configs mock_get_manifest.side_effect = ( lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) ) - mock_model_deploy.return_value = default_predictor mock_session.return_value = sagemaker_session @@ -1756,19 +1753,15 @@ def test_model_list_deployment_configs( self.assertEqual(configs, get_base_deployment_configs()) @mock.patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs") - @mock.patch("sagemaker.jumpstart.model.get_instance_rate_per_hour") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") - @mock.patch("sagemaker.jumpstart.model.Model.deploy") @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) def test_model_list_deployment_configs_empty( self, - mock_model_deploy: mock.Mock, mock_get_model_specs: mock.Mock, mock_session: mock.Mock, mock_get_manifest: mock.Mock, - mock_get_instance_rate_per_hour: mock.Mock, mock_verify_model_region_and_return_specs: mock.Mock, ): model_id, _ = "pytorch-eqa-bert-base-cased", "*" @@ -1776,16 +1769,10 @@ def test_model_list_deployment_configs_empty( mock_verify_model_region_and_return_specs.side_effect = ( lambda *args, **kwargs: get_special_model_spec(model_id="gemma-model") ) - mock_get_instance_rate_per_hour.side_effect = lambda *args, **kwargs: { - "name": "Instance Rate", - "unit": "USD/Hrs", - "value": "0.0083000000", - } mock_get_model_specs.side_effect = get_prototype_spec_with_configs mock_get_manifest.side_effect = ( lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) ) - mock_model_deploy.return_value = default_predictor mock_session.return_value = sagemaker_session @@ -1797,7 +1784,7 @@ def test_model_list_deployment_configs_empty( @mock.patch("sagemaker.jumpstart.model.get_init_kwargs") @mock.patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs") - @mock.patch("sagemaker.jumpstart.model.get_instance_rate_per_hour") + @mock.patch("sagemaker.jumpstart.model.add_instance_rate_stats_to_benchmark_metrics") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -1809,7 +1796,7 @@ def test_model_retrieve_deployment_config( mock_get_model_specs: mock.Mock, mock_session: mock.Mock, mock_get_manifest: mock.Mock, - mock_get_instance_rate_per_hour: mock.Mock, + mock_add_instance_rate_stats_to_benchmark_metrics: mock.Mock, mock_verify_model_region_and_return_specs: mock.Mock, mock_get_init_kwargs: mock.Mock, ): @@ -1818,18 +1805,17 @@ def test_model_retrieve_deployment_config( mock_verify_model_region_and_return_specs.side_effect = ( lambda *args, **kwargs: get_base_spec_with_prototype_configs_with_missing_benchmarks() ) - mock_get_instance_rate_per_hour.side_effect = lambda *args, **kwargs: { - "name": "Instance Rate", - "unit": "USD/Hrs", - "value": "0.0083000000", - } + mock_add_instance_rate_stats_to_benchmark_metrics.side_effect = lambda region, metrics: ( + None, + append_instance_stat_metrics(metrics), + ) mock_get_model_specs.side_effect = get_prototype_spec_with_configs mock_get_manifest.side_effect = ( lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) ) mock_model_deploy.return_value = default_predictor - expected = get_base_deployment_configs()[0] + expected = get_base_deployment_configs(True)[0] config_name = expected.get("DeploymentConfigName") instance_type = expected.get("InstanceType") mock_get_init_kwargs.side_effect = lambda *args, **kwargs: get_mock_init_kwargs( @@ -1849,19 +1835,17 @@ def test_model_retrieve_deployment_config( @mock.patch("sagemaker.jumpstart.model.get_init_kwargs") @mock.patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs") - @mock.patch("sagemaker.jumpstart.model.get_instance_rate_per_hour") + @mock.patch("sagemaker.jumpstart.model.add_instance_rate_stats_to_benchmark_metrics") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") - @mock.patch("sagemaker.jumpstart.model.Model.deploy") @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) def test_model_display_benchmark_metrics( self, - mock_model_deploy: mock.Mock, mock_get_model_specs: mock.Mock, mock_session: mock.Mock, mock_get_manifest: mock.Mock, - mock_get_instance_rate_per_hour: mock.Mock, + mock_add_instance_rate_stats_to_benchmark_metrics: mock.Mock, mock_verify_model_region_and_return_specs: mock.Mock, mock_get_init_kwargs: mock.Mock, ): @@ -1871,16 +1855,14 @@ def test_model_display_benchmark_metrics( mock_verify_model_region_and_return_specs.side_effect = ( lambda *args, **kwargs: get_base_spec_with_prototype_configs_with_missing_benchmarks() ) - mock_get_instance_rate_per_hour.side_effect = lambda *args, **kwargs: { - "name": "Instance Rate", - "unit": "USD/Hrs", - "value": "0.0083000000", - } + mock_add_instance_rate_stats_to_benchmark_metrics.side_effect = lambda region, metrics: ( + None, + append_instance_stat_metrics(metrics), + ) mock_get_model_specs.side_effect = get_prototype_spec_with_configs mock_get_manifest.side_effect = ( lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) ) - mock_model_deploy.return_value = default_predictor mock_session.return_value = sagemaker_session @@ -1890,19 +1872,17 @@ def test_model_display_benchmark_metrics( @mock.patch("sagemaker.jumpstart.model.get_init_kwargs") @mock.patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs") - @mock.patch("sagemaker.jumpstart.model.get_instance_rate_per_hour") + @mock.patch("sagemaker.jumpstart.model.add_instance_rate_stats_to_benchmark_metrics") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") - @mock.patch("sagemaker.jumpstart.model.Model.deploy") @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) def test_model_benchmark_metrics( self, - mock_model_deploy: mock.Mock, mock_get_model_specs: mock.Mock, mock_session: mock.Mock, mock_get_manifest: mock.Mock, - mock_get_instance_rate_per_hour: mock.Mock, + mock_add_instance_rate_stats_to_benchmark_metrics: mock.Mock, mock_verify_model_region_and_return_specs: mock.Mock, mock_get_init_kwargs: mock.Mock, ): @@ -1912,16 +1892,14 @@ def test_model_benchmark_metrics( mock_verify_model_region_and_return_specs.side_effect = ( lambda *args, **kwargs: get_base_spec_with_prototype_configs() ) - mock_get_instance_rate_per_hour.side_effect = lambda *args, **kwargs: { - "name": "Instance Rate", - "unit": "USD/Hrs", - "value": "0.0083000000", - } + mock_add_instance_rate_stats_to_benchmark_metrics.side_effect = lambda region, metrics: ( + None, + append_instance_stat_metrics(metrics), + ) mock_get_model_specs.side_effect = get_prototype_spec_with_configs mock_get_manifest.side_effect = ( lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) ) - mock_model_deploy.return_value = default_predictor mock_session.return_value = sagemaker_session diff --git a/tests/unit/sagemaker/jumpstart/test_types.py b/tests/unit/sagemaker/jumpstart/test_types.py index 5ca01c3c52..c52bf76f4e 100644 --- a/tests/unit/sagemaker/jumpstart/test_types.py +++ b/tests/unit/sagemaker/jumpstart/test_types.py @@ -22,6 +22,8 @@ JumpStartModelSpecs, JumpStartModelHeader, JumpStartConfigComponent, + DeploymentConfigMetadata, + JumpStartModelInitKwargs, ) from tests.unit.sagemaker.jumpstart.constants import ( BASE_SPEC, @@ -29,6 +31,7 @@ INFERENCE_CONFIGS, TRAINING_CONFIG_RANKINGS, TRAINING_CONFIGS, + INIT_KWARGS, ) INSTANCE_TYPE_VARIANT = JumpStartInstanceTypeVariants( @@ -1248,3 +1251,39 @@ def test_set_training_config(): with pytest.raises(ValueError) as error: specs1.set_config("invalid_name", scope="unknown scope") + + +def test_deployment_config_metadata(): + spec = {**BASE_SPEC, **INFERENCE_CONFIGS, **INFERENCE_CONFIG_RANKINGS} + specs = JumpStartModelSpecs(spec) + jumpstart_config = specs.inference_configs.get_top_config_from_ranking() + + deployment_config_metadata = DeploymentConfigMetadata( + jumpstart_config.config_name, + jumpstart_config.benchmark_metrics, + jumpstart_config.resolved_config, + JumpStartModelInitKwargs( + model_id=specs.model_id, + model_data=INIT_KWARGS.get("model_data"), + image_uri=INIT_KWARGS.get("image_uri"), + instance_type=INIT_KWARGS.get("instance_type"), + env=INIT_KWARGS.get("env"), + config_name=jumpstart_config.config_name, + ), + ) + + json_obj = deployment_config_metadata.to_json() + + assert isinstance(json_obj, dict) + assert json_obj["DeploymentConfigName"] == jumpstart_config.config_name + for key in json_obj["BenchmarkMetrics"]: + assert len(json_obj["BenchmarkMetrics"][key]) == len( + jumpstart_config.benchmark_metrics.get(key) + ) + assert json_obj["AccelerationConfigs"] == jumpstart_config.resolved_config.get( + "acceleration_configs" + ) + assert json_obj["DeploymentArgs"]["ImageUri"] == INIT_KWARGS.get("image_uri") + assert json_obj["DeploymentArgs"]["ModelData"] == INIT_KWARGS.get("model_data") + assert json_obj["DeploymentArgs"]["Environment"] == INIT_KWARGS.get("env") + assert json_obj["DeploymentArgs"]["InstanceType"] == INIT_KWARGS.get("instance_type") diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index f576e36185..f7458a29e9 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -13,6 +13,8 @@ from __future__ import absolute_import import os from unittest import TestCase + +from botocore.exceptions import ClientError from mock.mock import Mock, patch import pytest import boto3 @@ -49,8 +51,7 @@ get_spec_from_base_spec, get_special_model_spec, get_prototype_manifest, - get_base_deployment_configs, - get_base_deployment_configs_with_acceleration_configs, + get_base_deployment_configs_metadata, ) from mock import MagicMock @@ -1763,53 +1764,103 @@ def test_get_jumpstart_benchmark_stats_training( } +def test_extract_metrics_from_deployment_configs(): + configs = get_base_deployment_configs_metadata() + configs[0].benchmark_metrics = None + configs[2].deployment_args = None + + data = utils.get_metrics_from_deployment_configs(configs) + + for key in data: + assert len(data[key]) == (len(configs) - 2) + + +@patch("sagemaker.jumpstart.utils.get_instance_rate_per_hour") +def test_add_instance_rate_stats_to_benchmark_metrics( + mock_get_instance_rate_per_hour, +): + mock_get_instance_rate_per_hour.side_effect = lambda *args, **kwargs: { + "name": "Instance Rate", + "unit": "USD/Hrs", + "value": "3.76", + } + + err, out = utils.add_instance_rate_stats_to_benchmark_metrics( + "us-west-2", + { + "ml.p2.xlarge": [ + JumpStartBenchmarkStat({"name": "Latency", "value": "100", "unit": "Tokens/S"}) + ], + "ml.gd4.xlarge": [ + JumpStartBenchmarkStat({"name": "Latency", "value": "100", "unit": "Tokens/S"}) + ], + }, + ) + + assert err is None + for key in out: + assert len(out[key]) == 2 + for metric in out[key]: + if metric.name == "Instance Rate": + assert metric.to_json() == { + "name": "Instance Rate", + "unit": "USD/Hrs", + "value": "3.76", + } + + +@patch("sagemaker.jumpstart.utils.get_instance_rate_per_hour") +def test_add_instance_rate_stats_to_benchmark_metrics_client_ex( + mock_get_instance_rate_per_hour, +): + mock_get_instance_rate_per_hour.side_effect = ClientError( + {"Error": {"Message": "is not authorized to perform: pricing:GetProducts"}}, "GetProducts" + ) + + err, out = utils.add_instance_rate_stats_to_benchmark_metrics( + "us-west-2", + { + "ml.p2.xlarge": [ + JumpStartBenchmarkStat({"name": "Latency", "value": "100", "unit": "Tokens/S"}) + ], + }, + ) + + assert err == "is not authorized to perform: pricing:GetProducts" + for key in out: + assert len(out[key]) == 1 + + +@patch("sagemaker.jumpstart.utils.get_instance_rate_per_hour") +def test_add_instance_rate_stats_to_benchmark_metrics_ex( + mock_get_instance_rate_per_hour, +): + mock_get_instance_rate_per_hour.side_effect = Exception() + + err, out = utils.add_instance_rate_stats_to_benchmark_metrics( + "us-west-2", + { + "ml.p2.xlarge": [ + JumpStartBenchmarkStat({"name": "Latency", "value": "100", "unit": "Tokens/S"}) + ], + }, + ) + + assert err == "Unable to get instance rate per hour for instance type: ml.p2.xlarge." + for key in out: + assert len(out[key]) == 1 + + @pytest.mark.parametrize( - "config_name, configs, expected", + "stats, expected", [ + (None, False), ( - None, - get_base_deployment_configs(), - { - "Config Name": [ - "neuron-inference", - "neuron-inference-budget", - "gpu-inference-budget", - "gpu-inference", - ], - "Instance Type": ["ml.p2.xlarge", "ml.p2.xlarge", "ml.p2.xlarge", "ml.p2.xlarge"], - "Selected": ["No", "No", "No", "No"], - "Instance Rate (USD/Hrs)": [ - "0.0083000000", - "0.0083000000", - "0.0083000000", - "0.0083000000", - ], - }, - ), - ( - "neuron-inference", - get_base_deployment_configs_with_acceleration_configs(), - { - "Config Name": [ - "neuron-inference", - "neuron-inference-budget", - "gpu-inference-budget", - "gpu-inference", - ], - "Instance Type": ["ml.p2.xlarge", "ml.p2.xlarge", "ml.p2.xlarge", "ml.p2.xlarge"], - "Selected": ["Yes", "No", "No", "No"], - "Accelerated": ["Yes", "No", "No", "No"], - "Instance Rate (USD/Hrs)": [ - "0.0083000000", - "0.0083000000", - "0.0083000000", - "0.0083000000", - ], - }, + [JumpStartBenchmarkStat({"name": "Instance Rate", "unit": "USD/Hrs", "value": "3.76"})], + True, ), + ([JumpStartBenchmarkStat({"name": "Latency", "value": "100", "unit": "Tokens/S"})], False), ], ) -def test_extract_metrics_from_deployment_configs(config_name, configs, expected): - data = utils.get_metrics_from_deployment_configs(configs, config_name) - - assert data == expected +def test_has_instance_rate_stat(stats, expected): + assert utils.has_instance_rate_stat(stats) is expected diff --git a/tests/unit/sagemaker/jumpstart/utils.py b/tests/unit/sagemaker/jumpstart/utils.py index 8b814c3d71..e8a93dff6c 100644 --- a/tests/unit/sagemaker/jumpstart/utils.py +++ b/tests/unit/sagemaker/jumpstart/utils.py @@ -29,6 +29,9 @@ JumpStartS3FileType, JumpStartModelHeader, JumpStartModelInitKwargs, + DeploymentConfigMetadata, + JumpStartModelDeployKwargs, + JumpStartBenchmarkStat, ) from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.jumpstart.utils import get_formatted_manifest @@ -323,10 +326,6 @@ def overwrite_dictionary( return base_dictionary -def get_base_deployment_configs() -> List[Dict[str, Any]]: - return DEPLOYMENT_CONFIGS - - def get_base_deployment_configs_with_acceleration_configs() -> List[Dict[str, Any]]: configs = copy.deepcopy(DEPLOYMENT_CONFIGS) configs[0]["AccelerationConfigs"] = [ @@ -348,3 +347,60 @@ def get_mock_init_kwargs( resources=ResourceRequirements(), config_name=config_name, ) + + +def get_base_deployment_configs_metadata( + omit_benchmark_metrics: bool = False, +) -> List[DeploymentConfigMetadata]: + specs = ( + get_base_spec_with_prototype_configs_with_missing_benchmarks() + if omit_benchmark_metrics + else get_base_spec_with_prototype_configs() + ) + configs = [] + for config_name, jumpstart_config in specs.inference_configs.configs.items(): + benchmark_metrics = jumpstart_config.benchmark_metrics + + if benchmark_metrics: + for instance_type in benchmark_metrics: + benchmark_metrics[instance_type].append( + JumpStartBenchmarkStat( + {"name": "Instance Rate", "unit": "USD/Hrs", "value": "3.76"} + ) + ) + + configs.append( + DeploymentConfigMetadata( + config_name=config_name, + benchmark_metrics=jumpstart_config.benchmark_metrics, + resolved_config=jumpstart_config.resolved_config, + init_kwargs=get_mock_init_kwargs( + get_base_spec_with_prototype_configs().model_id, config_name + ), + deploy_kwargs=JumpStartModelDeployKwargs( + model_id=get_base_spec_with_prototype_configs().model_id, + ), + ) + ) + return configs + + +def get_base_deployment_configs( + omit_benchmark_metrics: bool = False, +) -> List[Dict[str, Any]]: + return [ + config.to_json() for config in get_base_deployment_configs_metadata(omit_benchmark_metrics) + ] + + +def append_instance_stat_metrics( + metrics: Dict[str, List[JumpStartBenchmarkStat]] +) -> Dict[str, List[JumpStartBenchmarkStat]]: + if metrics is not None: + for key in metrics: + metrics[key].append( + JumpStartBenchmarkStat( + {"name": "Instance Rate", "value": "3.76", "unit": "USD/Hrs"} + ) + ) + return metrics diff --git a/tests/unit/sagemaker/serve/builder/test_js_builder.py b/tests/unit/sagemaker/serve/builder/test_js_builder.py index b83b113209..56b01cd9e3 100644 --- a/tests/unit/sagemaker/serve/builder/test_js_builder.py +++ b/tests/unit/sagemaker/serve/builder/test_js_builder.py @@ -752,9 +752,11 @@ def test_set_deployment_config( mock_pre_trained_model.return_value.image_uri = mock_tgi_image_uri builder.build() - builder.set_deployment_config("config-1") + builder.set_deployment_config("config-1", "ml.g5.24xlarge") - mock_pre_trained_model.return_value.set_deployment_config.assert_called_with("config-1") + mock_pre_trained_model.return_value.set_deployment_config.assert_called_with( + "config-1", "ml.g5.24xlarge" + ) @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) @patch( @@ -789,7 +791,7 @@ def test_set_deployment_config_ex( "Cannot set deployment config to an uninitialized model.", lambda: ModelBuilder( model="facebook/galactica-mock-model-id", schema_builder=mock_schema_builder - ).set_deployment_config("config-2"), + ).set_deployment_config("config-2", "ml.g5.24xlarge"), ) @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index bf6a7cb09f..e94f3087ad 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -1871,43 +1871,103 @@ def test_deep_override_skip_keys(self): @pytest.mark.parametrize( - "instance, region", + "instance, region, amazon_sagemaker_price_result, expected", [ - ("t4g.nano", "us-west-2"), - ("t4g.nano", "eu-central-1"), - ("t4g.nano", "af-south-1"), - ("t4g.nano", "ap-northeast-2"), - ("t4g.nano", "cn-north-1"), + ( + "ml.t4g.nano", + "us-west-2", + { + "PriceList": [ + { + "terms": { + "OnDemand": { + "3WK7G7WSYVS3K492.JRTCKXETXF": { + "priceDimensions": { + "3WK7G7WSYVS3K492.JRTCKXETXF.6YS6EN2CT7": { + "unit": "Hrs", + "endRange": "Inf", + "description": "$0.9 per Unused Reservation Linux p2.xlarge Instance Hour", + "appliesTo": [], + "rateCode": "3WK7G7WSYVS3K492.JRTCKXETXF.6YS6EN2CT7", + "beginRange": "0", + "pricePerUnit": {"USD": "0.9000000000"}, + } + } + } + } + }, + } + ] + }, + {"name": "Instance Rate", "unit": "USD/Hrs", "value": "0.9"}, + ), + ( + "ml.t4g.nano", + "eu-central-1", + { + "PriceList": [ + '{"terms": {"OnDemand": {"22VNQ3N6GZGZMXYM.JRTCKXETXF": {"priceDimensions":{' + '"22VNQ3N6GZGZMXYM.JRTCKXETXF.6YS6EN2CT7": {"unit": "Hrs", "endRange": "Inf", "description": ' + '"$0.0083 per' + "On" + 'Demand Ubuntu Pro t4g.nano Instance Hour", "appliesTo": [], "rateCode": ' + '"22VNQ3N6GZGZMXYM.JRTCKXETXF.6YS6EN2CT7", "beginRange": "0", "pricePerUnit":{"USD": ' + '"0.0083000000"}}},' + '"sku": "22VNQ3N6GZGZMXYM", "effectiveDate": "2024-04-01T00:00:00Z", "offerTermCode": "JRTCKXETXF",' + '"termAttributes": {}}}}}' + ] + }, + {"name": "Instance Rate", "unit": "USD/Hrs", "value": "0.008"}, + ), + ( + "ml.t4g.nano", + "af-south-1", + { + "PriceList": [ + '{"terms": {"OnDemand": {"22VNQ3N6GZGZMXYM.JRTCKXETXF": {"priceDimensions":{' + '"22VNQ3N6GZGZMXYM.JRTCKXETXF.6YS6EN2CT7": {"unit": "Hrs", "endRange": "Inf", "description": ' + '"$0.0083 per' + "On" + 'Demand Ubuntu Pro t4g.nano Instance Hour", "appliesTo": [], "rateCode": ' + '"22VNQ3N6GZGZMXYM.JRTCKXETXF.6YS6EN2CT7", "beginRange": "0", "pricePerUnit":{"USD": ' + '"0.0083000000"}}},' + '"sku": "22VNQ3N6GZGZMXYM", "effectiveDate": "2024-04-01T00:00:00Z", "offerTermCode": "JRTCKXETXF",' + '"termAttributes": {}}}}}' + ] + }, + {"name": "Instance Rate", "unit": "USD/Hrs", "value": "0.008"}, + ), + ( + "ml.t4g.nano", + "ap-northeast-2", + { + "PriceList": [ + '{"terms": {"OnDemand": {"22VNQ3N6GZGZMXYM.JRTCKXETXF": {"priceDimensions":{' + '"22VNQ3N6GZGZMXYM.JRTCKXETXF.6YS6EN2CT7": {"unit": "Hrs", "endRange": "Inf", "description": ' + '"$0.0083 per' + "On" + 'Demand Ubuntu Pro t4g.nano Instance Hour", "appliesTo": [], "rateCode": ' + '"22VNQ3N6GZGZMXYM.JRTCKXETXF.6YS6EN2CT7", "beginRange": "0", "pricePerUnit":{"USD": ' + '"0.0083000000"}}},' + '"sku": "22VNQ3N6GZGZMXYM", "effectiveDate": "2024-04-01T00:00:00Z", "offerTermCode": "JRTCKXETXF",' + '"termAttributes": {}}}}}' + ] + }, + {"name": "Instance Rate", "unit": "USD/Hrs", "value": "0.008"}, + ), ], ) @patch("boto3.client") -def test_get_instance_rate_per_hour(mock_client, instance, region): - amazon_sagemaker_price_result = { - "PriceList": [ - '{"terms": {"OnDemand": {"22VNQ3N6GZGZMXYM.JRTCKXETXF": {"priceDimensions":{' - '"22VNQ3N6GZGZMXYM.JRTCKXETXF.6YS6EN2CT7": {"unit": "Hrs", "endRange": "Inf", "description": "$0.0083 per ' - "On" - 'Demand Ubuntu Pro t4g.nano Instance Hour", "appliesTo": [], "rateCode": ' - '"22VNQ3N6GZGZMXYM.JRTCKXETXF.6YS6EN2CT7", "beginRange": "0", "pricePerUnit":{"USD": "0.0083000000"}}}, ' - '"sku": "22VNQ3N6GZGZMXYM", "effectiveDate": "2024-04-01T00:00:00Z", "offerTermCode": "JRTCKXETXF", ' - '"termAttributes": {}}}}}' - ] - } +def test_get_instance_rate_per_hour( + mock_client, instance, region, amazon_sagemaker_price_result, expected +): mock_client.return_value.get_products.side_effect = ( lambda *args, **kwargs: amazon_sagemaker_price_result ) instance_rate = get_instance_rate_per_hour(instance_type=instance, region=region) - assert instance_rate == {"name": "Instance Rate", "unit": "USD/Hrs", "value": "0.0083000000"} - - -@patch("boto3.client") -def test_get_instance_rate_per_hour_ex(mock_client): - mock_client.return_value.get_products.side_effect = lambda *args, **kwargs: Exception() - instance_rate = get_instance_rate_per_hour(instance_type="ml.t4g.nano", region="us-west-2") - - assert instance_rate is None + assert instance_rate == expected @pytest.mark.parametrize( @@ -1934,7 +1994,7 @@ def test_get_instance_rate_per_hour_ex(mock_client): } } }, - {"name": "Instance Rate", "unit": "USD/Hrs", "value": "0.9000000000"}, + {"name": "Instance Rate", "unit": "USD/Hrs", "value": "0.9"}, ), ], )