Skip to content

Commit

Permalink
breaking: preserve script path when S3 source_dir is provided (#941)
Browse files Browse the repository at this point in the history
  • Loading branch information
laurenyu authored Jul 16, 2020
1 parent db21a38 commit 211f4e5
Show file tree
Hide file tree
Showing 13 changed files with 69 additions and 29 deletions.
2 changes: 1 addition & 1 deletion src/sagemaker/chainer/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ def create_model(
return ChainerModel(
self.model_data,
role or self.role,
entry_point or self.entry_point,
entry_point or self._model_entry_point(),
source_dir=(source_dir or self._model_source_dir()),
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics,
container_log_level=self.container_log_level,
Expand Down
19 changes: 15 additions & 4 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1734,17 +1734,28 @@ def _stage_user_code_in_s3(self):
)

def _model_source_dir(self):
    """Get the appropriate value to pass as ``source_dir`` to a model constructor.

    Returns:
        str: Either a local or an S3 path pointing to the ``source_dir`` to be
            used for code by the model to be deployed. ``None`` is returned
            when the session is not in local mode and no uploaded code is
            recorded on this estimator.
    """
    if self.sagemaker_session.local_mode:
        return self.source_dir

    # NOTE(review): assumes ``uploaded_code`` can be None when no code was
    # staged by this estimator instance — confirm against callers. Returning
    # None here (instead of raising AttributeError) lets callers that test
    # ``_model_source_dir() is None`` fall back cleanly.
    if self.uploaded_code is None:
        return None

    return self.uploaded_code.s3_prefix

def _model_entry_point(self):
    """Get the appropriate value to pass as ``entry_point`` to a model constructor.

    Returns:
        str: The path to the entry point script. This can be either an absolute path or
            a path relative to ``self._model_source_dir()``.
    """
    # In local mode, or when there is no source directory to resolve against,
    # hand back the entry point exactly as it was given to the estimator.
    if self.sagemaker_session.local_mode or (self._model_source_dir() is None):
        return self.entry_point

    # Otherwise use the script name recorded alongside the uploaded code.
    return self.uploaded_code.script_name

def hyperparameters(self):
"""Return the hyperparameters as a dictionary to use for training.
Expand Down
2 changes: 1 addition & 1 deletion src/sagemaker/fw_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,7 @@ def tar_and_upload_dir(
script name.
"""
if directory and directory.lower().startswith("s3://"):
return UploadedCode(s3_prefix=directory, script_name=os.path.basename(script))
return UploadedCode(s3_prefix=directory, script_name=script)

script_name = script if directory else os.path.basename(script)
dependencies = dependencies or []
Expand Down
11 changes: 9 additions & 2 deletions src/sagemaker/mxnet/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,10 +218,10 @@ def create_model(

kwargs["name"] = self._get_or_create_name(kwargs.get("name"))

return MXNetModel(
model = MXNetModel(
self.model_data,
role or self.role,
entry_point or self.entry_point,
entry_point,
framework_version=self.framework_version,
py_version=self.py_version,
source_dir=(source_dir or self._model_source_dir()),
Expand All @@ -235,6 +235,13 @@ def create_model(
**kwargs
)

if entry_point is None:
model.entry_point = (
self.entry_point if model._is_mms_version() else self._model_entry_point()
)

return model

@classmethod
def _prepare_init_params_from_job_description(cls, job_details, model_channel_name=None):
"""Convert the job description to init params that can be handled by the
Expand Down
2 changes: 1 addition & 1 deletion src/sagemaker/pytorch/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def create_model(
return PyTorchModel(
self.model_data,
role or self.role,
entry_point or self.entry_point,
entry_point or self._model_entry_point(),
framework_version=self.framework_version,
py_version=self.py_version,
source_dir=(source_dir or self._model_source_dir()),
Expand Down
2 changes: 1 addition & 1 deletion src/sagemaker/rl/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def create_model(
if not entry_point and (source_dir or dependencies):
raise AttributeError("Please provide an `entry_point`.")

entry_point = entry_point or self.entry_point
entry_point = entry_point or self._model_entry_point()
source_dir = source_dir or self._model_source_dir()
dependencies = dependencies or self.dependencies

Expand Down
2 changes: 1 addition & 1 deletion src/sagemaker/sklearn/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ def create_model(
return SKLearnModel(
self.model_data,
role,
entry_point or self.entry_point,
entry_point or self._model_entry_point(),
source_dir=(source_dir or self._model_source_dir()),
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics,
container_log_level=self.container_log_level,
Expand Down
2 changes: 1 addition & 1 deletion src/sagemaker/xgboost/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def create_model(
return XGBoostModel(
self.model_data,
role,
entry_point or self.entry_point,
entry_point or self._model_entry_point(),
framework_version=self.framework_version,
source_dir=(source_dir or self._model_source_dir()),
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics,
Expand Down
Binary file added tests/data/mxnet_mnist/sourcedir.tar.gz
Binary file not shown.
21 changes: 16 additions & 5 deletions tests/integ/test_mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,16 @@ def mxnet_training_job(
sagemaker_session, mxnet_full_version, mxnet_full_py_version, cpu_instance_type
):
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
script_path = os.path.join(DATA_DIR, "mxnet_mnist", "mnist.py")
s3_prefix = "integ-test-data/mxnet_mnist"
data_path = os.path.join(DATA_DIR, "mxnet_mnist")

s3_source = sagemaker_session.upload_data(
path=os.path.join(data_path, "sourcedir.tar.gz"), key_prefix="{}/src".format(s3_prefix)
)

mx = MXNet(
entry_point=script_path,
entry_point="mxnet_mnist/mnist.py",
source_dir=s3_source,
role="SageMakerRole",
framework_version=mxnet_full_version,
py_version=mxnet_full_py_version,
Expand All @@ -46,10 +51,10 @@ def mxnet_training_job(
)

train_input = mx.sagemaker_session.upload_data(
path=os.path.join(data_path, "train"), key_prefix="integ-test-data/mxnet_mnist/train"
path=os.path.join(data_path, "train"), key_prefix="{}/train".format(s3_prefix)
)
test_input = mx.sagemaker_session.upload_data(
path=os.path.join(data_path, "test"), key_prefix="integ-test-data/mxnet_mnist/test"
path=os.path.join(data_path, "test"), key_prefix="{}/test".format(s3_prefix)
)

mx.fit({"train": train_input, "test": test_input})
Expand All @@ -62,7 +67,13 @@ def test_attach_deploy(mxnet_training_job, sagemaker_session, cpu_instance_type)

with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
estimator = MXNet.attach(mxnet_training_job, sagemaker_session=sagemaker_session)
predictor = estimator.deploy(1, cpu_instance_type, endpoint_name=endpoint_name)
predictor = estimator.deploy(
1,
cpu_instance_type,
entry_point="mnist.py",
source_dir=os.path.join(DATA_DIR, "mxnet_mnist"),
endpoint_name=endpoint_name,
)
data = numpy.zeros(shape=(1, 1, 28, 28))
result = predictor.predict(data)
assert result is not None
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1269,7 +1269,7 @@ def test_git_support_codecommit_with_ssh_no_passphrase_needed(git_clone_repo, sa
@patch("time.strftime", return_value=TIMESTAMP)
def test_init_with_source_dir_s3(strftime, sagemaker_session):
fw = DummyFramework(
entry_point=SCRIPT_PATH,
entry_point=SCRIPT_NAME,
source_dir="s3://location",
role=ROLE,
sagemaker_session=sagemaker_session,
Expand Down
12 changes: 12 additions & 0 deletions tests/unit/test_fw_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -873,6 +873,18 @@ def test_tar_and_upload_dir_s3(sagemaker_session):
assert result == fw_utils.UploadedCode("s3://m", "mnist.py")


def test_tar_and_upload_dir_s3_with_script_dir(sagemaker_session):
    """Check that a nested script path is kept intact when ``directory`` is an S3 URI."""
    bucket = "mybucket"
    s3_key_prefix = "something/source"
    script = "some/dir/mnist.py"
    directory = "s3://m"
    result = fw_utils.tar_and_upload_dir(
        sagemaker_session, bucket, s3_key_prefix, script, directory
    )

    # The script's directory component must not be stripped to a basename.
    assert result == fw_utils.UploadedCode("s3://m", "some/dir/mnist.py")


@patch("sagemaker.utils")
def test_tar_and_upload_dir_s3_with_kms(utils, sagemaker_session):
bucket = "mybucket"
Expand Down
21 changes: 10 additions & 11 deletions tests/unit/test_mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@
from sagemaker.mxnet import MXNetPredictor, MXNetModel

DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data")
SCRIPT_PATH = os.path.join(DATA_DIR, "dummy_script.py")
SCRIPT_NAME = "dummy_script.py"
SCRIPT_PATH = os.path.join(DATA_DIR, SCRIPT_NAME)
SERVING_SCRIPT_FILE = "another_dummy_script.py"
MODEL_DATA = "s3://mybucket/model"
ENV = {"DUMMY_ENV_VAR": "dummy_value"}
Expand Down Expand Up @@ -189,7 +190,8 @@ def test_create_model(name_from_base, sagemaker_session, mxnet_version, mxnet_py
base_job_name = "job"

mx = MXNet(
entry_point=SCRIPT_PATH,
entry_point=SCRIPT_NAME,
source_dir=source_dir,
framework_version=mxnet_version,
py_version=mxnet_py_version,
role=ROLE,
Expand All @@ -198,7 +200,6 @@ def test_create_model(name_from_base, sagemaker_session, mxnet_version, mxnet_py
instance_type=INSTANCE_TYPE,
container_log_level=container_log_level,
base_job_name=base_job_name,
source_dir=source_dir,
)

mx.fit(inputs="s3://mybucket/train", job_name="new_name")
Expand All @@ -210,7 +211,7 @@ def test_create_model(name_from_base, sagemaker_session, mxnet_version, mxnet_py
assert model.sagemaker_session == sagemaker_session
assert model.framework_version == mxnet_version
assert model.py_version == mxnet_py_version
assert model.entry_point == SCRIPT_PATH
assert model.entry_point == SCRIPT_NAME
assert model.role == ROLE
assert model.name == model_name
assert model.container_log_level == container_log_level
Expand All @@ -226,7 +227,8 @@ def test_create_model_with_optional_params(sagemaker_session, mxnet_version, mxn
source_dir = "s3://mybucket/source"
enable_cloudwatch_metrics = "true"
mx = MXNet(
entry_point=SCRIPT_PATH,
entry_point=SCRIPT_NAME,
source_dir=source_dir,
framework_version=mxnet_version,
py_version=mxnet_py_version,
role=ROLE,
Expand All @@ -235,7 +237,6 @@ def test_create_model_with_optional_params(sagemaker_session, mxnet_version, mxn
instance_type=INSTANCE_TYPE,
container_log_level=container_log_level,
base_job_name="job",
source_dir=source_dir,
enable_cloudwatch_metrics=enable_cloudwatch_metrics,
)

Expand Down Expand Up @@ -270,7 +271,8 @@ def test_create_model_with_custom_image(name_from_base, sagemaker_session):
base_job_name = "job"

mx = MXNet(
entry_point=SCRIPT_PATH,
entry_point=SCRIPT_NAME,
source_dir=source_dir,
framework_version="2.0",
py_version="py3",
role=ROLE,
Expand All @@ -280,7 +282,6 @@ def test_create_model_with_custom_image(name_from_base, sagemaker_session):
image_uri=custom_image,
container_log_level=container_log_level,
base_job_name=base_job_name,
source_dir=source_dir,
)

mx.fit(inputs="s3://mybucket/train", job_name="new_name")
Expand All @@ -291,7 +292,7 @@ def test_create_model_with_custom_image(name_from_base, sagemaker_session):

assert model.sagemaker_session == sagemaker_session
assert model.image_uri == custom_image
assert model.entry_point == SCRIPT_PATH
assert model.entry_point == SCRIPT_NAME
assert model.role == ROLE
assert model.name == model_name
assert model.container_log_level == container_log_level
Expand Down Expand Up @@ -730,7 +731,6 @@ def test_model_py2_warning(warning, sagemaker_session):

def test_create_model_with_custom_hosting_image(sagemaker_session):
container_log_level = '"logging.INFO"'
source_dir = "s3://mybucket/source"
custom_image = "mxnet:2.0"
custom_hosting_image = "mxnet_hosting:2.0"
mx = MXNet(
Expand All @@ -744,7 +744,6 @@ def test_create_model_with_custom_hosting_image(sagemaker_session):
image_uri=custom_image,
container_log_level=container_log_level,
base_job_name="job",
source_dir=source_dir,
)

mx.fit(inputs="s3://mybucket/train", job_name="new_name")
Expand Down

0 comments on commit 211f4e5

Please sign in to comment.