Skip to content

Commit

Permalink
fix: register jumpstart models on model registry
Browse files Browse the repository at this point in the history
  • Loading branch information
selvask-aws committed Jun 25, 2024
1 parent 58bb448 commit b73a795
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 21 deletions.
1 change: 0 additions & 1 deletion src/sagemaker/jumpstart/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -761,7 +761,6 @@ def register(
source_uri: Optional[Union[str, PipelineVariable]] = None,
model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None,
accept_eula: Optional[bool] = None,

):
"""Creates a model package for creating SageMaker models or listing on Marketplace.
Expand Down
1 change: 0 additions & 1 deletion src/sagemaker/jumpstart/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2410,7 +2410,6 @@ class JumpStartModelRegisterKwargs(JumpStartKwargs):
"model_version",
"hub_arn",
"sagemaker_session",
"model_type",
}

def __init__(
Expand Down
31 changes: 16 additions & 15 deletions src/sagemaker/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,7 @@ def register(
source_uri: Optional[Union[str, PipelineVariable]] = None,
model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None,
accept_eula: Optional[bool] = None,
model_type: Optional[JumpStartModelType] = None,
):
"""Creates a model package for creating SageMaker models or listing on Marketplace.
Expand Down Expand Up @@ -517,7 +518,10 @@ def register(

if image_uri is not None:
self.image_uri = image_uri
if self.model_type is not JumpStartModelType.PROPRIETARY:
if model_type is JumpStartModelType.PROPRIETARY:
source_uri = self.model_package_arn
model_package_group_name = self.model_id
else:
if model_package_group_name is None and model_package_name is None:
# If model package group and model package name is not set
# then register to auto-generated model package group
Expand All @@ -533,23 +537,20 @@ def register(
data_input_configuration=data_input_configuration,
container_def=container_def,
)
else:
container_def = {
"Image": self.image_uri,
}
else:
container_def = {
"Image": self.image_uri,
}

if isinstance(self.model_data, dict):
raise ValueError(
"Un-versioned SageMaker Model Package currently cannot be "
"created with ModelDataSource."
)
if isinstance(self.model_data, dict):
raise ValueError(
"Un-versioned SageMaker Model Package currently cannot be "
"created with ModelDataSource."
)

if self.model_data is not None:
container_def["ModelDataUrl"] = self.model_data
if self.model_data is not None:
container_def["ModelDataUrl"] = self.model_data

if self.model_type is JumpStartModelType.PROPRIETARY:
source_uri = self.model_package_arn
model_package_group_name = self.model_id

model_pkg_args = sagemaker.get_model_package_args(
self.content_types,
Expand Down
8 changes: 4 additions & 4 deletions tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,7 @@ def test_proprietary_jumpstart_model(setup):

assert response is not None


@pytest.mark.skipif(
True,
reason="Only enable if test account is subscribed to the proprietary model",
Expand All @@ -309,7 +310,6 @@ def test_register_proprietary_jumpstart_model(setup):
sagemaker_session=get_sm_session(),
)
model_package = model.register()


predictor = model_package.deploy(
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}]
Expand All @@ -329,7 +329,7 @@ def test_register_proprietary_jumpstart_model(setup):
)
def test_register_gated_jumpstart_model(setup):

model_id="meta-textgenerationneuron-llama-2-7b"
model_id = "meta-textgenerationneuron-llama-2-7b"
model = JumpStartModel(
model_id=model_id,
model_version="1.1.0",
Expand All @@ -339,7 +339,8 @@ def test_register_gated_jumpstart_model(setup):
model_package = model.register(accept_eula=True)

predictor = model_package.deploy(
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}], accept_eula=True
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
accept_eula=True,
)
payload = {"prompt": "To be, or", "maxTokens": 4, "temperature": 0, "numResults": 1}

Expand All @@ -348,4 +349,3 @@ def test_register_gated_jumpstart_model(setup):
predictor.delete_predictor()

assert response is not None

4 changes: 4 additions & 0 deletions tests/unit/sagemaker/jumpstart/model/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,8 +513,11 @@ def test_proprietary_model_endpoint(
model.deploy()

mock_model_register.assert_called_once_with(
model_type=JumpStartModelType.PROPRIETARY,
content_types=["application/json"],
response_types=["application/json"],
model_package_group_name=model_id,
source_uri=model.model_package_arn
)

mock_model_deploy.assert_called_once_with(
Expand Down Expand Up @@ -1416,6 +1419,7 @@ def test_model_registry_accept_and_response_types(
model.register()

mock_model_register.assert_called_once_with(
model_type=JumpStartModelType.OPEN_WEIGHTS,
content_types=["application/x-text"],
response_types=["application/json;verbose", "application/json"],
)
Expand Down

0 comments on commit b73a795

Please sign in to comment.