Deployment Configs - Follow-ups #4626

Merged
Testing
Jonathan Makunga committed Apr 27, 2024
commit 1d678ac1bf15a659fb7e2b27ebcc2b82fafc542a
21 changes: 15 additions & 6 deletions src/sagemaker/jumpstart/model.py
@@ -449,7 +449,10 @@ def deployment_config(self) -> Optional[Dict[str, Any]]:
Returns:
Optional[Dict[str, Any]]: Deployment config that will be applied to the model.
"""
-        return self._retrieve_selected_deployment_config(self.config_name, self.instance_type)
+        deployment_config = self._retrieve_selected_deployment_config(
+            self.config_name, self.instance_type
+        )
+        return deployment_config.to_json() if deployment_config is not None else None

@property
def benchmark_metrics(self) -> pd.DataFrame:
@@ -470,7 +473,11 @@ def list_deployment_configs(self) -> List[Dict[str, Any]]:
Returns:
List[Dict[str, Any]]: A list of deployment configs.
"""
-        return self._get_deployment_configs(self.instance_type)
+        # Temp
+        return [
+            deployment_config.to_json()
+            for deployment_config in self._get_deployment_configs(self.instance_type)
+        ]

def _create_sagemaker_model(
self,
@@ -871,7 +878,7 @@ def _get_benchmarks_data(self) -> Dict[str, Any]:
@lru_cache
def _retrieve_selected_deployment_config(
self, config_name: str, instance_type: str
-    ) -> Optional[Dict[str, Any]]:
+    ) -> Optional[DeploymentConfigMetadata]:
"""Retrieve the deployment config to apply to the model.

Args:
@@ -883,12 +890,14 @@ def _retrieve_selected_deployment_config(
return None

for deployment_config in self._get_deployment_configs(instance_type):
-            if deployment_config.get("DeploymentConfigName") == config_name:
+            if deployment_config.deployment_config_name == config_name:
return deployment_config
return None

@lru_cache
-    def _get_deployment_configs(self, selected_instance_type: str) -> List[Dict[str, Any]]:
+    def _get_deployment_configs(
+        self, selected_instance_type: str
+    ) -> List[DeploymentConfigMetadata]:
"""Retrieve the deployment configs to apply to the model."""
deployment_configs = []

@@ -927,7 +936,7 @@ def _get_deployment_configs(self, selected_instance_type: str) -> List[Dict[str,
init_kwargs,
deploy_kwargs,
)
-            deployment_configs.append(deployment_config_metadata.to_json())
+            deployment_configs.append(deployment_config_metadata)

return deployment_configs

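For context, here is a minimal usage sketch (not part of this commit) of the public surface touched above. The model_id is a placeholder; with this change, list_deployment_configs() and the deployment_config property still hand callers plain dicts, but those dicts are now produced by calling to_json() on DeploymentConfigMetadata objects at the boundary.

from sagemaker.jumpstart.model import JumpStartModel

# Placeholder model_id; any JumpStart model that exposes deployment configs would do.
model = JumpStartModel(model_id="<jumpstart-model-id>")

# Still a List[Dict[str, Any]]: each DeploymentConfigMetadata is serialized via to_json().
for config in model.list_deployment_configs():
    print(config.get("DeploymentConfigName"), config.get("DeploymentArgs"))

# The selected config as a JSON-style dict, or None when nothing matches
# the model's config_name and instance_type.
print(model.deployment_config)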
6 changes: 2 additions & 4 deletions src/sagemaker/jumpstart/types.py
@@ -2265,9 +2265,7 @@ class DeploymentArgs(BaseDeploymentConfigDataHolder):
        "image_uri",
        "model_data",
        "environment",
-        "selected_instance_type",
-        "default_instance_type",
-        "supported_instance_types",
+        "instance_type",
"compute_resource_requirements",
"model_data_download_timeout",
"container_startup_health_check_timeout",
@@ -2283,7 +2281,7 @@ def __init__(
if init_kwargs is not None:
self.image_uri = init_kwargs.image_uri
self.model_data = init_kwargs.model_data
-            self.selected_instance_type = init_kwargs.instance_type
+            self.instance_type = init_kwargs.instance_type
self.default_instance_type = resolved_config.get("default_inference_instance_type")
self.supported_instance_types = resolved_config.get(
"supported_inference_instance_types"
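To illustrate the rename, a sketch that is not taken from the PR: DeploymentArgs now serializes a single instance_type field alongside the default and supported instance types pulled from the resolved config, and to_json() appears to emit PascalCase keys, matching the dict lookups being removed from utils.py below. All values here are invented placeholders.

# Hypothetical serialized DeploymentArgs entry after the rename (illustrative values only).
deployment_args = {
    "ImageUri": "<inference image uri>",
    "ModelData": "<model artifact S3 uri>",
    "InstanceType": "ml.g5.2xlarge",        # previously surfaced as SelectedInstanceType
    "DefaultInstanceType": "ml.g5.2xlarge",
    "SupportedInstanceTypes": ["ml.g5.2xlarge", "ml.g5.12xlarge"],
}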
22 changes: 12 additions & 10 deletions src/sagemaker/jumpstart/utils.py
@@ -41,6 +41,7 @@
JumpStartModelHeader,
JumpStartModelSpecs,
JumpStartVersionedModelId,
+    DeploymentConfigMetadata,
)
from sagemaker.session import Session
from sagemaker.config import load_sagemaker_config
@@ -1031,7 +1032,7 @@ def get_jumpstart_configs(


def get_metrics_from_deployment_configs(
-    deployment_configs: List[Dict[str, Any]]
+    deployment_configs: List[DeploymentConfigMetadata],
) -> Dict[str, List[str]]:
"""Extracts metrics from deployment configs.

@@ -1042,34 +1043,35 @@ def get_metrics_from_deployment_configs(
data = {"Config Name": [], "Instance Type": []}

for index, deployment_config in enumerate(deployment_configs):
-        if deployment_config.get("DeploymentArgs") is None:
+        if deployment_config.deployment_args is None:
continue

-        benchmark_metrics = deployment_config.get("BenchmarkMetrics")
+        benchmark_metrics = deployment_config.benchmark_metrics
for current_instance_type, current_instance_type_metrics in benchmark_metrics.items():

-            data["Config Name"].append(deployment_config.get("DeploymentConfigName"))
+            data["Config Name"].append(deployment_config.deployment_config_name)
instance_type_to_display = (
f"{current_instance_type} (Default)"
-                if current_instance_type
-                == deployment_config.get("DeploymentArgs").get("DefaultInstanceType")
+                if current_instance_type == deployment_config.deployment_args.default_instance_type
else current_instance_type
)
data["Instance Type"].append(instance_type_to_display)

if index == 0:
temp_data = {}
for metric in current_instance_type_metrics:
-                    column_name = f"{metric.get('name')} ({metric.get('unit')})"
-                    if metric.get("name").lower() == "instance rate":
+                    column_name = f"{metric.name} ({metric.unit})"
+                    if metric.name.lower() == "instance rate":
data[column_name] = []
else:
temp_data[column_name] = []
data = {**data, **temp_data}

for metric in current_instance_type_metrics:
-                column_name = f"{metric.get('name')} ({metric.get('unit')})"
+                column_name = f"{metric.name} ({metric.unit})"
if column_name in data:
-                    data[column_name].append(metric.get('value'))
+                    for _ in range(index):
+                        data[column_name].append(" - ")
+                    data[column_name].append(metric.value)

return data
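As a closing note, a small sketch of how the column-oriented dict returned by get_metrics_from_deployment_configs can back the benchmark_metrics DataFrame exposed on the model in model.py above. Assumptions: pandas is available, the DeploymentConfigMetadata objects come from _get_deployment_configs, and benchmark_metrics_frame is a hypothetical helper name, not part of the SDK.

import pandas as pd

from sagemaker.jumpstart.utils import get_metrics_from_deployment_configs


def benchmark_metrics_frame(deployment_configs) -> pd.DataFrame:
    # Hypothetical helper: one row per (config, instance type), one column per metric.
    data = get_metrics_from_deployment_configs(deployment_configs)
    return pd.DataFrame(data)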