Add supported inference and incremental training configs #4637

Merged
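For context, the config kwargs touched below are what let a caller pin a specific JumpStart inference or incremental-training config, with a fallback to the top-ranked config otherwise. Here is a minimal usage sketch; it assumes `config_name` is exposed on the public `JumpStartModel` constructor (as the `JumpStartModelInitKwargs` in this diff suggests), and the model ID and config name are placeholders.

```python
# Minimal usage sketch (assumption: config_name is a supported keyword on the
# public JumpStartModel constructor, as JumpStartModelInitKwargs in this diff
# suggests). The model ID and config name below are placeholders.
from sagemaker.jumpstart.model import JumpStartModel

# With no config_name, _add_config_name_to_init_kwargs falls back to the
# top-ranked inference config for the model, when one exists.
model = JumpStartModel(model_id="meta-textgeneration-llama-2-7b")

# An explicit config_name overrides the ranked default.
pinned = JumpStartModel(
    model_id="meta-textgeneration-llama-2-7b",
    config_name="neuron-inference",  # hypothetical config name
)
```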
Changes from 3 commits
src/sagemaker/jumpstart/factory/estimator.py (3 changes: 1 addition & 2 deletions)
```diff
@@ -29,7 +29,6 @@
     _retrieve_model_package_model_artifact_s3_uri,
 )
 from sagemaker.jumpstart.artifacts.resource_names import _retrieve_resource_name_base
-from sagemaker.jumpstart.session_utils import get_model_info_from_training_job
 from sagemaker.session import Session
 from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig
 from sagemaker.base_deserializers import BaseDeserializer
@@ -815,7 +814,7 @@ def _add_config_name_to_kwargs(
         config_name=kwargs.config_name,
     )
 
-    if specs.training_configs and specs.training_configs.get_top_config_from_ranking().config_name:
+    if specs.training_configs and specs.training_configs.get_top_config_from_ranking():
         kwargs.config_name = (
             kwargs.config_name or specs.training_configs.get_top_config_from_ranking().config_name
         )
```
src/sagemaker/jumpstart/factory/model.py (21 changes: 7 additions & 14 deletions)
```diff
@@ -588,13 +588,9 @@ def _add_config_name_to_init_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpSta
         model_type=kwargs.model_type,
         config_name=kwargs.config_name,
     )
-    if (
-        specs.inference_configs
-        and specs.inference_configs.get_top_config_from_ranking().config_name
-    ):
-        kwargs.config_name = (
-            kwargs.config_name or specs.inference_configs.get_top_config_from_ranking().config_name
-        )
+    if specs.inference_configs and specs.inference_configs.get_top_config_from_ranking():
+        default_config_name = specs.inference_configs.get_top_config_from_ranking().config_name
+        kwargs.config_name = kwargs.config_name or default_config_name
 
     if not kwargs.config_name:
         return kwargs
@@ -614,6 +610,7 @@ def _add_config_name_to_init_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpSta
 
     return kwargs
 
+
 def _add_config_name_to_deploy_kwargs(
     kwargs: JumpStartModelDeployKwargs, training_config_name: Optional[str] = None
 ) -> JumpStartModelInitKwargs:
```
```diff
@@ -643,13 +640,9 @@ def _add_config_name_to_deploy_kwargs(
         specs=specs, training_config_name=training_config_name
     )
 
-    if (
-        specs.inference_configs
-        and specs.inference_configs.get_top_config_from_ranking().config_name
-    ):
-        kwargs.config_name = (
-            kwargs.config_name or specs.inference_configs.get_top_config_from_ranking().config_name
-        )
+    if specs.inference_configs and specs.inference_configs.get_top_config_from_ranking():
+        default_config_name = specs.inference_configs.get_top_config_from_ranking().config_name
+        kwargs.config_name = kwargs.config_name or default_config_name
 
     return kwargs
```

Review thread on the refactored if-check:

Member: It seems you just got rid of `.config_name` from the if-statement.

Collaborator (Author): Sorry, I extracted it to `default_config_name` and removed the redundant `.config_name` check on `.get_top_config_from_ranking()` in the if clause. It is not needed since it falls back to None anyway.
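To make that reasoning concrete, here is a small self-contained sketch with stand-in types (hypothetical, not the real JumpStart config classes): with or without the extra `.config_name` check in the `if`, the resolved config name comes out the same, because `existing or default` simply yields None when both are None, and the later `if not kwargs.config_name` guard covers that case.

```python
# Stand-in sketch (hypothetical types, not the real JumpStart classes) showing
# why dropping the `.config_name` check does not change the resolved value:
# `existing or default` falls back to None when both are None, and the later
# `if not kwargs.config_name` guard already handles that case.
from dataclasses import dataclass
from typing import Optional


@dataclass
class TopConfig:
    config_name: Optional[str] = None


def resolve_config_name(existing: Optional[str], top: Optional[TopConfig]) -> Optional[str]:
    """Mirrors the refactored branch: only check that a top-ranked config exists."""
    if top:
        default_config_name = top.config_name
        existing = existing or default_config_name
    return existing


assert resolve_config_name(None, TopConfig("neuron-inference")) == "neuron-inference"
assert resolve_config_name("my-config", TopConfig("neuron-inference")) == "my-config"
assert resolve_config_name(None, TopConfig(None)) is None  # falls back to None
assert resolve_config_name(None, None) is None
```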