chore: require config name and instance type in set_deployment_config #4625

Merged

Changes from 1 commit
12 changes: 9 additions & 3 deletions src/sagemaker/jumpstart/model.py
@@ -429,16 +429,22 @@ def retrieve_example_payload(self) -> JumpStartSerializablePayload:
sagemaker_session=self.sagemaker_session,
)

- def set_deployment_config(self, config_name: Optional[str]) -> None:
+ def set_deployment_config(self, config_name: str, instance_type: str) -> None:
"""Sets the deployment config to apply to the model.

Args:
- config_name (Optional[str]):
+ config_name (str):
The name of the deployment config. Set to None to unset
any existing config that is applied to the model.
+ instance_type (str):
+ The instance_type that the model will use after setting
+ the config.
"""
self.__init__(
- model_id=self.model_id, model_version=self.model_version, config_name=config_name
+ model_id=self.model_id,
+ model_version=self.model_version,
+ instance_type=instance_type,
+ config_name=config_name,
)

@property
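For reference, a minimal usage sketch of the updated call, assuming the prototype model ID and the config name / instance type values used in the unit tests below (all values are illustrative, not prescribed by this PR):

from sagemaker.jumpstart.model import JumpStartModel

# Construct a JumpStart model; "pytorch-eqa-bert-base-cased" is the prototype ID from the tests.
model = JumpStartModel(model_id="pytorch-eqa-bert-base-cased")

# Both arguments are now required: config_name selects the deployment config,
# and instance_type is the instance the model will use once the config is applied.
model.set_deployment_config(config_name="neuron-inference", instance_type="ml.inf2.xlarge")

# A subsequent deploy() uses the instance type supplied above (see the updated unit test).
predictor = model.deploy()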
59 changes: 6 additions & 53 deletions tests/unit/sagemaker/jumpstart/model/test_model.py
@@ -1614,15 +1614,15 @@ def test_model_set_deployment_config(
mock_get_model_specs.reset_mock()
mock_model_deploy.reset_mock()
mock_get_model_specs.side_effect = get_prototype_spec_with_configs
model.set_deployment_config("neuron-inference")
model.set_deployment_config("neuron-inference", "ml.inf2.2xlarge")

assert model.config_name == "neuron-inference"

model.deploy()

mock_model_deploy.assert_called_once_with(
initial_instance_count=1,
instance_type="ml.inf2.xlarge",
instance_type="ml.inf2.2xlarge",
tags=[
{"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"},
{"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"},
@@ -1631,34 +1631,8 @@
wait=True,
endpoint_logging=False,
)

- @mock.patch(
- "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {}
- )
- @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest")
- @mock.patch("sagemaker.jumpstart.factory.model.Session")
- @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
- @mock.patch("sagemaker.jumpstart.model.Model.deploy")
- @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region)
- def test_model_unset_deployment_config(
- self,
- mock_model_deploy: mock.Mock,
- mock_get_model_specs: mock.Mock,
- mock_session: mock.Mock,
- mock_get_manifest: mock.Mock,
- mock_get_jumpstart_configs: mock.Mock,
- ):
- mock_get_model_specs.side_effect = get_prototype_spec_with_configs
- mock_get_manifest.side_effect = (
- lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type)
- )
- mock_model_deploy.return_value = default_predictor
-
- model_id, _ = "pytorch-eqa-bert-base-cased", "*"
-
- mock_session.return_value = sagemaker_session
-
- model = JumpStartModel(model_id=model_id, config_name="neuron-inference")
+ mock_model_deploy.reset_mock()
+ model.set_deployment_config("neuron-inference", "ml.inf2.xlarge")

assert model.config_name == "neuron-inference"

@@ -1676,24 +1650,6 @@ def test_model_unset_deployment_config(
endpoint_logging=False,
)

- mock_get_model_specs.reset_mock()
- mock_model_deploy.reset_mock()
- mock_get_model_specs.side_effect = get_prototype_model_spec
- model.set_deployment_config(None)
-
- model.deploy()
-
- mock_model_deploy.assert_called_once_with(
- initial_instance_count=1,
- instance_type="ml.p2.xlarge",
- tags=[
- {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"},
- {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"},
- ],
- wait=True,
- endpoint_logging=False,
- )

@mock.patch("sagemaker.jumpstart.model.get_init_kwargs")
@mock.patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs")
@mock.patch("sagemaker.jumpstart.model.get_instance_rate_per_hour")
@@ -1813,6 +1769,7 @@ def test_model_retrieve_deployment_config(

expected = get_base_deployment_configs()[0]
config_name = expected.get("DeploymentConfigName")
+ instance_type = expected.get("InstanceType")
mock_get_init_kwargs.side_effect = lambda *args, **kwargs: get_mock_init_kwargs(
model_id, config_name
)
@@ -1821,17 +1778,13 @@

model = JumpStartModel(model_id=model_id)

- model.set_deployment_config(config_name)
+ model.set_deployment_config(config_name, instance_type)

self.assertEqual(model.deployment_config, expected)

mock_get_init_kwargs.reset_mock()
mock_get_init_kwargs.side_effect = lambda *args, **kwargs: get_mock_init_kwargs(model_id)

- # Unset
- model.set_deployment_config(None)
- self.assertIsNone(model.deployment_config)

@mock.patch("sagemaker.jumpstart.model.get_init_kwargs")
@mock.patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs")
@mock.patch("sagemaker.jumpstart.model.get_instance_rate_per_hour")
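With the None shortcut removed from these tests, the sketch below shows one way a caller might drive the new two-argument call from the configs a model advertises. This is a sketch under assumptions, not code from this PR: it assumes JumpStartModel.list_deployment_configs() is available and returns dicts keyed by "DeploymentConfigName" and "InstanceType", mirroring the fixture used in test_model_retrieve_deployment_config; the model ID is illustrative.

from sagemaker.jumpstart.model import JumpStartModel

model = JumpStartModel(model_id="pytorch-eqa-bert-base-cased")  # illustrative model ID

# Assumption: list_deployment_configs() returns the model's deployment configs as dicts
# whose keys match the fixture in test_model_retrieve_deployment_config.
configs = model.list_deployment_configs()
chosen = configs[0]

# Both arguments are required now; there is no None form to clear the config.
model.set_deployment_config(
    config_name=chosen["DeploymentConfigName"],
    instance_type=chosen["InstanceType"],
)

# The applied config can be read back through the deployment_config property.
print(model.deployment_config)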