Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support config_name in all JumpStart interfaces #4583

Merged
merged 12 commits into from
Apr 22, 2024
3 changes: 3 additions & 0 deletions src/sagemaker/accept_types.py
Original file line number Diff line number Diff line change
@@ -77,6 +77,7 @@ def retrieve_default(
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
config_name: Optional[str] = None,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in the future, we should think of a way to consolidate all these JS related fields perhaps into a dataclass or kwargs, so we don't need to update all these function prototypes whenever a new feature is added.

) -> str:
"""Retrieves the default accept type for the model matching the given arguments.

@@ -98,6 +99,7 @@ def retrieve_default(
object, used for SageMaker interactions. If not
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

check with @judyheflin about how to describe this

Returns:
str: The default accept type to use for the model.

@@ -117,4 +119,5 @@ def retrieve_default(
tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
model_type=model_type,
config_name=config_name,
)
3 changes: 3 additions & 0 deletions src/sagemaker/content_types.py
Original file line number Diff line number Diff line change
@@ -77,6 +77,7 @@ def retrieve_default(
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
config_name: Optional[str] = None,
) -> str:
"""Retrieves the default content type for the model matching the given arguments.

@@ -98,6 +99,7 @@ def retrieve_default(
object, used for SageMaker interactions. If not
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
Returns:
str: The default content type to use for the model.

@@ -117,6 +119,7 @@ def retrieve_default(
tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
model_type=model_type,
config_name=config_name,
)


3 changes: 3 additions & 0 deletions src/sagemaker/deserializers.py
Original file line number Diff line number Diff line change
@@ -97,6 +97,7 @@ def retrieve_default(
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
config_name: Optional[str] = None,
) -> BaseDeserializer:
"""Retrieves the default deserializer for the model matching the given arguments.

@@ -118,6 +119,7 @@ def retrieve_default(
object, used for SageMaker interactions. If not
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
Returns:
BaseDeserializer: The default deserializer to use for the model.

@@ -138,4 +140,5 @@ def retrieve_default(
tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
model_type=model_type,
config_name=config_name,
)
3 changes: 3 additions & 0 deletions src/sagemaker/environment_variables.py
Original file line number Diff line number Diff line change
@@ -36,6 +36,7 @@ def retrieve_default(
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
instance_type: Optional[str] = None,
script: JumpStartScriptScope = JumpStartScriptScope.INFERENCE,
config_name: Optional[str] = None,
) -> Dict[str, str]:
"""Retrieves the default container environment variables for the model matching the arguments.

@@ -65,6 +66,7 @@ def retrieve_default(
variables specific for the instance type.
script (JumpStartScriptScope): The JumpStart script for which to retrieve environment
variables.
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
Returns:
dict: The variables to use for the model.

@@ -87,4 +89,5 @@ def retrieve_default(
sagemaker_session=sagemaker_session,
instance_type=instance_type,
script=script,
config_name=config_name,
)
3 changes: 3 additions & 0 deletions src/sagemaker/hyperparameters.py
Original file line number Diff line number Diff line change
@@ -36,6 +36,7 @@ def retrieve_default(
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
config_name: Optional[str] = None,
) -> Dict[str, str]:
"""Retrieves the default training hyperparameters for the model matching the given arguments.

@@ -66,6 +67,7 @@ def retrieve_default(
object, used for SageMaker interactions. If not
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
Returns:
dict: The hyperparameters to use for the model.

@@ -86,6 +88,7 @@ def retrieve_default(
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
config_name=config_name,
)


3 changes: 3 additions & 0 deletions src/sagemaker/image_uris.py
Original file line number Diff line number Diff line change
@@ -68,6 +68,7 @@ def retrieve(
inference_tool=None,
serverless_inference_config=None,
sagemaker_session=DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
config_name=None,
) -> str:
"""Retrieves the ECR URI for the Docker image matching the given arguments.

@@ -121,6 +122,7 @@ def retrieve(
object, used for SageMaker interactions. If not
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).

Returns:
str: The ECR URI for the corresponding SageMaker Docker image.
@@ -160,6 +162,7 @@ def retrieve(
tolerate_vulnerable_model,
tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
config_name=config_name,
)

if training_compiler_config and (framework in [HUGGING_FACE_FRAMEWORK, "pytorch"]):
3 changes: 3 additions & 0 deletions src/sagemaker/instance_types.py
Original file line number Diff line number Diff line change
@@ -36,6 +36,7 @@ def retrieve_default(
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
training_instance_type: Optional[str] = None,
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
config_name: Optional[str] = None,
) -> str:
"""Retrieves the default instance type for the model matching the given arguments.

@@ -64,6 +65,7 @@ def retrieve_default(
Optionally supply this to get a inference instance type conditioned
on the training instance, to ensure compatability of training artifact to inference
instance. (Default: None).
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
Returns:
str: The default instance type to use for the model.

@@ -88,6 +90,7 @@ def retrieve_default(
sagemaker_session=sagemaker_session,
training_instance_type=training_instance_type,
model_type=model_type,
config_name=config_name,
)


8 changes: 8 additions & 0 deletions src/sagemaker/jumpstart/artifacts/environment_variables.py
Original file line number Diff line number Diff line change
@@ -39,6 +39,7 @@ def _retrieve_default_environment_variables(
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
instance_type: Optional[str] = None,
script: JumpStartScriptScope = JumpStartScriptScope.INFERENCE,
config_name: Optional[str] = None,
) -> Dict[str, str]:
"""Retrieves the inference environment variables for the model matching the given arguments.

@@ -68,6 +69,7 @@ def _retrieve_default_environment_variables(
environment variables specific for the instance type.
script (JumpStartScriptScope): The JumpStart script for which to retrieve
environment variables.
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
Returns:
dict: the inference environment variables to use for the model.
"""
@@ -84,6 +86,7 @@ def _retrieve_default_environment_variables(
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
config_name=config_name,
)

default_environment_variables: Dict[str, str] = {}
@@ -121,7 +124,9 @@ def _retrieve_default_environment_variables(
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
instance_type=instance_type,
config_name=config_name,
)

)

gated_model_env_var: Optional[str] = retrieve_gated_env_var_for_instance_type(
@@ -167,6 +172,7 @@ def _retrieve_gated_model_uri_env_var_value(
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
instance_type: Optional[str] = None,
config_name: Optional[str] = None,
) -> Optional[str]:
"""Retrieves the gated model env var URI matching the given arguments.

@@ -190,6 +196,7 @@ def _retrieve_gated_model_uri_env_var_value(
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
instance_type (str): An instance type to optionally supply in order to get
environment variables specific for the instance type.
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).

Returns:
Optional[str]: the s3 URI to use for the environment variable, or None if the model does not
@@ -211,6 +218,7 @@ def _retrieve_gated_model_uri_env_var_value(
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
config_name=config_name,
)

s3_key: Optional[str] = (
3 changes: 3 additions & 0 deletions src/sagemaker/jumpstart/artifacts/hyperparameters.py
Original file line number Diff line number Diff line change
@@ -36,6 +36,7 @@ def _retrieve_default_hyperparameters(
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
instance_type: Optional[str] = None,
config_name: Optional[str] = None,
):
"""Retrieves the training hyperparameters for the model matching the given arguments.

@@ -66,6 +67,7 @@ def _retrieve_default_hyperparameters(
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
instance_type (str): An instance type to optionally supply in order to get hyperparameters
specific for the instance type.
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
Returns:
dict: the hyperparameters to use for the model.
"""
@@ -82,6 +84,7 @@ def _retrieve_default_hyperparameters(
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
config_name=config_name,
)

default_hyperparameters: Dict[str, str] = {}
4 changes: 4 additions & 0 deletions src/sagemaker/jumpstart/artifacts/image_uris.py
Original file line number Diff line number Diff line change
@@ -46,6 +46,7 @@ def _retrieve_image_uri(
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
config_name: Optional[str] = None,
):
"""Retrieves the container image URI for JumpStart models.

@@ -95,6 +96,7 @@ def _retrieve_image_uri(
object, used for SageMaker interactions. If not
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
Returns:
str: the ECR URI for the corresponding SageMaker Docker image.

@@ -116,6 +118,7 @@ def _retrieve_image_uri(
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
config_name=config_name,
)

if image_scope == JumpStartScriptScope.INFERENCE:
@@ -200,4 +203,5 @@ def _retrieve_image_uri(
distribution=distribution,
base_framework_version=base_framework_version_override or base_framework_version,
training_compiler_config=training_compiler_config,
config_name=config_name,
)
3 changes: 3 additions & 0 deletions src/sagemaker/jumpstart/artifacts/incremental_training.py
Original file line number Diff line number Diff line change
@@ -33,6 +33,7 @@ def _model_supports_incremental_training(
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
config_name: Optional[str] = None,
) -> bool:
"""Returns True if the model supports incremental training.

@@ -54,6 +55,7 @@ def _model_supports_incremental_training(
object, used for SageMaker interactions. If not
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
Returns:
bool: the support status for incremental training.
"""
@@ -70,6 +72,7 @@ def _model_supports_incremental_training(
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
config_name=config_name,
)

return model_specs.supports_incremental_training()
6 changes: 6 additions & 0 deletions src/sagemaker/jumpstart/artifacts/instance_types.py
Original file line number Diff line number Diff line change
@@ -40,6 +40,7 @@ def _retrieve_default_instance_type(
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
training_instance_type: Optional[str] = None,
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
config_name: Optional[str] = None,
) -> str:
"""Retrieves the default instance type for the model.

@@ -68,6 +69,7 @@ def _retrieve_default_instance_type(
Optionally supply this to get a inference instance type conditioned
on the training instance, to ensure compatability of training artifact to inference
instance. (Default: None).
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
Returns:
str: the default instance type to use for the model or None.

@@ -89,6 +91,7 @@ def _retrieve_default_instance_type(
tolerate_deprecated_model=tolerate_deprecated_model,
model_type=model_type,
sagemaker_session=sagemaker_session,
config_name=config_name,
)

if scope == JumpStartScriptScope.INFERENCE:
@@ -128,6 +131,7 @@ def _retrieve_instance_types(
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
training_instance_type: Optional[str] = None,
config_name: Optional[str] = None,
) -> List[str]:
"""Retrieves the supported instance types for the model.

@@ -156,6 +160,7 @@ def _retrieve_instance_types(
Optionally supply this to get a inference instance type conditioned
on the training instance, to ensure compatability of training artifact to inference
instance. (Default: None).
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
Returns:
list: the supported instance types to use for the model or None.

@@ -176,6 +181,7 @@ def _retrieve_instance_types(
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
config_name=config_name,
)

if scope == JumpStartScriptScope.INFERENCE:
Loading