Skip to content

Commit

Permalink
Implement CuratedHub APIs (aws#1449)
Browse files Browse the repository at this point in the history
* Implement CuratedHub Admin APIs

* making some parameters optional in create_hub_content_reference as per the API design

* add describe_hub and list_hubs APIs

* implement delete_hub API

* Implement list_hub_contents API

* create CuratedHub class and supported utils

* implement list_models and address comments

* Add unit tests

* add describe_model function

* cache retrieval for describeHubContent changes

* fix curated hub class unit tests

* add utils needed for curatedHub

* Cache retrieval

* implement get_hub_model_reference()

* cleanup HUB type datatype

* cleanup constants

* rename list_public_models to list_jumpstart_service_hub_models

* implement describe_model_reference

* Rename CuratedHub to Hub

* address nit

* address nits and fix failing tests

---------

Co-authored-by: Malav Shastri <[email protected]>
  • Loading branch information
malav-shastri and Malav Shastri authored May 29, 2024
1 parent 32d44fb commit f227999
Show file tree
Hide file tree
Showing 19 changed files with 2,500 additions and 98 deletions.
121 changes: 94 additions & 27 deletions src/sagemaker/jumpstart/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import datetime
from difflib import get_close_matches
import os
from typing import List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union
import json
import boto3
import botocore
Expand All @@ -42,12 +42,19 @@
JUMPSTART_DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_HORIZON,
)
from sagemaker.jumpstart.types import (
JumpStartCachedS3ContentKey,
JumpStartCachedS3ContentValue,
JumpStartCachedContentKey,
JumpStartCachedContentValue,
JumpStartModelHeader,
JumpStartModelSpecs,
JumpStartS3FileType,
JumpStartVersionedModelId,
HubType,
HubContentType
)
from sagemaker.jumpstart.hub import utils as hub_utils
from sagemaker.jumpstart.hub.interfaces import (
DescribeHubResponse,
DescribeHubContentResponse,
)
from sagemaker.jumpstart.enums import JumpStartModelType
from sagemaker.jumpstart import utils
Expand Down Expand Up @@ -104,7 +111,7 @@ def __init__(
s3_bucket_name=s3_bucket_name, s3_client=s3_client
)

self._s3_cache = LRUCache[JumpStartCachedS3ContentKey, JumpStartCachedS3ContentValue](
self._content_cache = LRUCache[JumpStartCachedContentKey, JumpStartCachedContentValue](
max_cache_items=max_s3_cache_items,
expiration_horizon=s3_cache_expiration_horizon,
retrieval_function=self._retrieval_function,
Expand Down Expand Up @@ -230,8 +237,8 @@ def _model_id_retrieval_function(

model_id, version = key.model_id, key.version
sm_version = utils.get_sagemaker_version()
manifest = self._s3_cache.get(
JumpStartCachedS3ContentKey(
manifest = self._content_cache.get(
JumpStartCachedContentKey(
MODEL_TYPE_TO_MANIFEST_MAP[model_type], self._manifest_file_s3_map[model_type]
)
)[0].formatted_content
Expand Down Expand Up @@ -392,53 +399,87 @@ def _get_json_file_from_local_override(

def _retrieval_function(
self,
key: JumpStartCachedS3ContentKey,
value: Optional[JumpStartCachedS3ContentValue],
) -> JumpStartCachedS3ContentValue:
"""Return s3 content given a file type and s3_key in ``JumpStartCachedS3ContentKey``.
key: JumpStartCachedContentKey,
value: Optional[JumpStartCachedContentValue],
) -> JumpStartCachedContentValue:
"""Return cached content for the data type and id_info in ``JumpStartCachedContentKey``.
If a manifest file is being fetched, we only download the object if the md5 hash in
``head_object`` does not match the current md5 hash for the stored value. This prevents
unnecessarily downloading the full manifest when it hasn't changed.
Args:
key (JumpStartCachedS3ContentKey): key for which to fetch s3 content.
key (JumpStartCachedContentKey): key for which to fetch s3 content.
value (Optional[JumpStartCachedContentValue]): Current value of old cached
s3 content. This is used for the manifest file, so that it is only
downloaded when its content changes.
"""

file_type, s3_key = key.file_type, key.s3_key
if file_type in {
data_type, id_info = key.data_type, key.id_info

if data_type in {
JumpStartS3FileType.OPEN_WEIGHT_MANIFEST,
JumpStartS3FileType.PROPRIETARY_MANIFEST,
}:
if value is not None and not self._is_local_metadata_mode():
etag = self._get_json_md5_hash(s3_key)
etag = self._get_json_md5_hash(id_info)
if etag == value.md5_hash:
return value
formatted_body, etag = self._get_json_file(s3_key, file_type)
return JumpStartCachedS3ContentValue(
formatted_body, etag = self._get_json_file(id_info, data_type)
return JumpStartCachedContentValue(
formatted_content=utils.get_formatted_manifest(formatted_body),
md5_hash=etag,
)
if file_type in {
if data_type in {
JumpStartS3FileType.OPEN_WEIGHT_SPECS,
JumpStartS3FileType.PROPRIETARY_SPECS,
}:
formatted_body, _ = self._get_json_file(s3_key, file_type)
formatted_body, _ = self._get_json_file(id_info, data_type)
model_specs = JumpStartModelSpecs(formatted_body)
utils.emit_logs_based_on_model_specs(model_specs, self.get_region(), self._s3_client)
return JumpStartCachedS3ContentValue(formatted_content=model_specs)
raise ValueError(self._file_type_error_msg(file_type))
return JumpStartCachedContentValue(
formatted_content=model_specs
)

if data_type == HubContentType.NOTEBOOK:
hub_name, _, notebook_name, notebook_version = hub_utils \
.get_info_from_hub_resource_arn(id_info)
response: Dict[str, Any] = self._sagemaker_session.describe_hub_content(
hub_name=hub_name,
hub_content_name=notebook_name,
hub_content_version=notebook_version,
hub_content_type=data_type,
)
hub_notebook_description = DescribeHubContentResponse(response)
return JumpStartCachedContentValue(formatted_content=hub_notebook_description)

if data_type in [HubContentType.MODEL, HubContentType.MODEL_REFERENCE]:
hub_name, _, model_name, model_version = hub_utils.get_info_from_hub_resource_arn(
id_info
)
hub_model_description: Dict[str, Any] = self._sagemaker_session.describe_hub_content(
hub_name=hub_name,
hub_content_name=model_name,
hub_content_version=model_version,
hub_content_type=data_type,
)

model_specs = make_model_specs_from_describe_hub_content_response(
DescribeHubContentResponse(hub_model_description),
)

utils.emit_logs_based_on_model_specs(model_specs, self.get_region(), self._s3_client)
return JumpStartCachedContentValue(formatted_content=model_specs)

raise ValueError(self._file_type_error_msg(data_type))

def get_manifest(
self,
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
) -> List[JumpStartModelHeader]:
"""Return entire JumpStart models manifest."""
manifest_dict = self._s3_cache.get(
JumpStartCachedS3ContentKey(
manifest_dict = self._content_cache.get(
JumpStartCachedContentKey(
MODEL_TYPE_TO_MANIFEST_MAP[model_type], self._manifest_file_s3_map[model_type]
)
)[0].formatted_content
Expand Down Expand Up @@ -525,8 +566,8 @@ def _get_header_impl(
JumpStartVersionedModelId(model_id, semantic_version_str)
)[0]

manifest = self._s3_cache.get(
JumpStartCachedS3ContentKey(
manifest = self._content_cache.get(
JumpStartCachedContentKey(
MODEL_TYPE_TO_MANIFEST_MAP[model_type], self._manifest_file_s3_map[model_type]
)
)[0].formatted_content
Expand Down Expand Up @@ -556,18 +597,44 @@ def get_specs(
"""
header = self.get_header(model_id, version_str, model_type)
spec_key = header.spec_key
specs, cache_hit = self._s3_cache.get(
JumpStartCachedS3ContentKey(MODEL_TYPE_TO_SPECS_MAP[model_type], spec_key)
specs, cache_hit = self._content_cache.get(
JumpStartCachedContentKey(MODEL_TYPE_TO_SPECS_MAP[model_type], spec_key)
)

if not cache_hit and "*" in version_str:
JUMPSTART_LOGGER.warning(
get_wildcard_model_version_msg(header.model_id, version_str, header.version)
)
return specs.formatted_content

def get_hub_model(self, hub_model_arn: str) -> JumpStartModelSpecs:
    """Look up the JumpStart-compatible specs of a Hub model.

    The lookup goes through the content cache, using a key built from the
    ``HubContentType.MODEL`` content type and the given ARN.

    Args:
        hub_model_arn (str): ARN of the Hub model whose specs are requested.

    Returns:
        JumpStartModelSpecs: The ``formatted_content`` of the cached entry.
    """
    cache_key = JumpStartCachedContentKey(HubContentType.MODEL, hub_model_arn)
    # The cache returns a pair; only the first element (the cached value) is needed here.
    cached_entry, _ = self._content_cache.get(cache_key)
    return cached_entry.formatted_content

def get_hub_model_reference(self, hub_model_arn: str) -> JumpStartModelSpecs:
    """Look up the JumpStart-compatible specs of a Hub model reference.

    The lookup goes through the content cache, using a key built from the
    ``HubContentType.MODEL_REFERENCE`` content type and the given ARN.

    Args:
        hub_model_arn (str): ARN of the Hub model reference whose specs are requested.

    Returns:
        JumpStartModelSpecs: The ``formatted_content`` of the cached entry.
    """
    cache_key = JumpStartCachedContentKey(HubContentType.MODEL_REFERENCE, hub_model_arn)
    # The cache returns a pair; only the first element (the cached value) is needed here.
    cached_entry, _ = self._content_cache.get(cache_key)
    return cached_entry.formatted_content

def clear(self) -> None:
"""Clears the model ID/version caches and the content cache."""
self._s3_cache.clear()
self._content_cache.clear()
self._open_weight_model_id_manifest_key_cache.clear()
self._proprietary_model_id_manifest_key_cache.clear()
3 changes: 3 additions & 0 deletions src/sagemaker/jumpstart/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,9 @@
JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY = "models_manifest.json"
JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY = "proprietary-sdk-manifest.json"

# Pattern for hub content ARNs. It has seven non-greedy capture groups: the field
# between "arn:" and ":sagemaker:", the two colon-separated fields that follow
# "sagemaker:", and the four "/"-separated segments after "hub-content/".
# NOTE(review): presumably partition, region, account id, hub name, content type,
# content name and content version, in that order — confirm against the parser
# that consumes this pattern.
HUB_CONTENT_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub-content/(.*?)/(.*?)/(.*?)/(.*?)$"
# Pattern for hub ARNs: the same three leading groups, then everything after
# "hub/" up to the end of the string (presumably the hub name — TODO confirm).
HUB_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub/(.*?)$"

INFERENCE_ENTRY_POINT_SCRIPT_NAME = "inference.py"
TRAINING_ENTRY_POINT_SCRIPT_NAME = "transfer_learning.py"

Expand Down
Empty file.
18 changes: 18 additions & 0 deletions src/sagemaker/jumpstart/hub/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""This module stores constants related to SageMaker JumpStart Hub."""
from __future__ import absolute_import

# Hub name string.
# NOTE(review): presumably the name of the JumpStart-managed service hub — confirm at call sites.
JUMPSTART_MODEL_HUB_NAME = "JumpStartServiceHub"

# NOTE(review): presumably the version string that stands for "latest version" — confirm at call sites.
LATEST_VERSION_WILDCARD = "*"
Loading

0 comments on commit f227999

Please sign in to comment.