Skip to content

Commit

Permalink
breaking: require framework_version, py_version for tensorflow (#1580)
Browse files Browse the repository at this point in the history
  • Loading branch information
metrizable authored Jun 12, 2020
1 parent 9df3f5a commit dbdaf50
Show file tree
Hide file tree
Showing 12 changed files with 173 additions and 164 deletions.
1 change: 1 addition & 0 deletions doc/frameworks/tensorflow/upgrade_from_legacy.rst
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ the difference in code would be as follows:
...
source_dir="code",
framework_version="1.10.0",
py_version="py2",
train_instance_type="ml.m4.xlarge",
image_name="520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tensorflow:1.10.0-cpu-py2",
hyperparameters={
Expand Down
5 changes: 3 additions & 2 deletions doc/frameworks/tensorflow/using_tf.rst
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ To run training job with Pipe input mode, pass in ``input_mode='Pipe'`` to your
tf_estimator = TensorFlow(entry_point='tf-train-with-pipemodedataset.py', role='SageMakerRole',
training_steps=10000, evaluation_steps=100,
train_instance_count=1, train_instance_type='ml.p2.xlarge',
framework_version='1.10.0', input_mode='Pipe')
framework_version='1.10.0', py_version='py3', input_mode='Pipe')
tf_estimator.fit('s3://bucket/path/to/training/data')
Expand Down Expand Up @@ -383,7 +383,8 @@ estimator object to create a SageMaker Endpoint:
from sagemaker.tensorflow import TensorFlow
estimator = TensorFlow(entry_point='tf-train.py', ..., train_instance_count=1,
train_instance_type='ml.c4.xlarge', framework_version='1.11')
train_instance_type='ml.c4.xlarge', framework_version='1.11',
py_version='py3')
estimator.fit(inputs)
Expand Down
34 changes: 14 additions & 20 deletions src/sagemaker/tensorflow/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,11 @@ def __init__(
Args:
py_version (str): Python version you want to use for executing your model training
code (default: 'py2').
code. Defaults to ``None``. Required unless ``image_name`` is provided.
framework_version (str): TensorFlow version you want to use for executing your model
training code. List of supported versions
training code. Defaults to ``None``. Required unless ``image_name`` is provided.
List of supported versions:
https://github.com/aws/sagemaker-python-sdk#tensorflow-sagemaker-estimators.
If not specified, this will default to 1.11.
model_dir (str): S3 location where the checkpoint data and models can be exported to
during training (default: None). It will be passed in the training script as one of
the command line arguments. If not specified, one is provided based on
Expand All @@ -81,6 +81,10 @@ def __init__(
Examples:
123.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0
custom-image:latest.
If ``framework_version`` or ``py_version`` are ``None``, then
``image_name`` is required. If also ``None``, then a ``ValueError``
will be raised.
distributions (dict): A dictionary with information on how to run distributed training
(default: None). Currently we support distributed training with parameter servers
and MPI.
Expand Down Expand Up @@ -114,18 +118,13 @@ def __init__(
:class:`~sagemaker.estimator.Framework` and
:class:`~sagemaker.estimator.EstimatorBase`.
"""
if framework_version is None:
logger.warning(
fw.empty_framework_version_warning(defaults.TF_VERSION, self.LATEST_VERSION)
)
self.framework_version = framework_version or defaults.TF_VERSION

if not py_version:
py_version = "py3" if self._only_python_3_supported() else "py2"
fw.validate_version_or_image_args(framework_version, py_version, image_name)
if py_version == "py2":
logger.warning(
fw.python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
)
self.framework_version = framework_version
self.py_version = py_version

if distributions is not None:
logger.warning(fw.parameter_v2_rename_warning("distribution", distributions))
Expand All @@ -136,32 +135,27 @@ def __init__(

if "enable_sagemaker_metrics" not in kwargs:
# enable sagemaker metrics for TF v1.15 or greater:
if fw.is_version_equal_or_higher([1, 15], self.framework_version):
if framework_version and fw.is_version_equal_or_higher([1, 15], framework_version):
kwargs["enable_sagemaker_metrics"] = True

super(TensorFlow, self).__init__(image_name=image_name, **kwargs)

self.py_version = py_version
self.model_dir = model_dir
self.distributions = distributions or {}

self._validate_args(py_version=py_version, framework_version=self.framework_version)
self._validate_args(py_version=py_version)

def _validate_args(self, py_version, framework_version):
def _validate_args(self, py_version):
"""Placeholder docstring"""

if py_version == "py3":
if framework_version is None:
raise AttributeError(fw.EMPTY_FRAMEWORK_VERSION_ERROR)

if py_version == "py2" and self._only_python_3_supported():
msg = (
"Python 2 containers are only available with {} and lower versions. "
"Please use a Python 3 container.".format(defaults.LATEST_PY2_VERSION)
)
raise AttributeError(msg)

if self._only_legacy_mode_supported() and self.image_name is None:
if self.image_name is None and self._only_legacy_mode_supported():
legacy_image_uri = fw.create_image_uri(
self.sagemaker_session.boto_region_name,
"tensorflow",
Expand Down
18 changes: 13 additions & 5 deletions src/sagemaker/tensorflow/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from sagemaker.content_types import CONTENT_TYPE_JSON
from sagemaker.fw_utils import create_image_uri
from sagemaker.predictor import json_serializer, json_deserializer
from sagemaker.tensorflow.defaults import TF_VERSION


class TensorFlowPredictor(sagemaker.RealTimePredictor):
Expand Down Expand Up @@ -138,7 +137,7 @@ def __init__(
role,
entry_point=None,
image=None,
framework_version=TF_VERSION,
framework_version=None,
container_log_level=None,
predictor_cls=TensorFlowPredictor,
**kwargs
Expand All @@ -158,9 +157,12 @@ def __init__(
hosting. If ``source_dir`` is specified, then ``entry_point``
must point to a file located at the root of ``source_dir``.
image (str): A Docker image URI (default: None). If not specified, a
default image for TensorFlow Serving will be used.
default image for TensorFlow Serving will be used. If
``framework_version`` is ``None``, then ``image`` is required.
If also ``None``, then a ``ValueError`` will be raised.
framework_version (str): Optional. TensorFlow Serving version you
want to use.
want to use. Defaults to ``None``. Required unless ``image`` is
provided.
container_log_level (int): Log level to use within the container
(default: logging.ERROR). Valid values are defined in the Python
logging module.
Expand All @@ -176,6 +178,13 @@ def __init__(
:class:`~sagemaker.model.FrameworkModel` and
:class:`~sagemaker.model.Model`.
"""
if framework_version is None and image is None:
raise ValueError(
"Both framework_version and image were None. "
"Either specify framework_version or specify image_name."
)
self.framework_version = framework_version

super(TensorFlowModel, self).__init__(
model_data=model_data,
role=role,
Expand All @@ -184,7 +193,6 @@ def __init__(
entry_point=entry_point,
**kwargs
)
self.framework_version = framework_version
self._container_log_level = container_log_level

def deploy(
Expand Down
27 changes: 27 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,16 @@ def tf_version(request):
return request.param


@pytest.fixture(scope="module", params=["py2", "py3"])
def tf_py_version(tf_version, request):
version = [int(val) for val in tf_version.split(".")]
if version < [1, 11]:
return "py2"
if version < [2, 2]:
return request.param
return "py37"


@pytest.fixture(scope="module", params=["0.10.1", "0.10.1", "0.11", "0.11.0", "0.11.1"])
def rl_coach_tf_version(request):
return request.param
Expand Down Expand Up @@ -290,6 +300,23 @@ def tf_full_version(request):
return tf_version


@pytest.fixture(scope="module")
def tf_full_py_version(tf_full_version, request):
"""fixture to match tf_full_version
Fixture exists as such, since tf_full_version may be overridden --tf-full-version.
Otherwise, this would simply be py37 to match the latest version support.
TODO: Evaluate use of --tf-full-version with possible eye to remove and simplify code.
"""
version = [int(val) for val in tf_full_version.split(".")]
if version < [1, 11]:
return "py2"
if tf_full_version in [TensorFlow._LATEST_1X_VERSION, LATEST_VERSION]:
return "py37"
return "py3"


@pytest.fixture(scope="module", params=["1.15.0", "2.0.0"])
def ei_tf_full_version(request):
tf_ei_version = request.config.getoption("--ei-tf-full-version")
Expand Down
2 changes: 1 addition & 1 deletion tests/integ/test_airflow_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,7 +562,7 @@ def test_tf_airflow_config_uploads_data_source_to_s3(sagemaker_session, cpu_inst
train_instance_type=cpu_instance_type,
sagemaker_session=sagemaker_session,
framework_version=TensorFlow.LATEST_VERSION,
py_version=PYTHON_VERSION,
py_version="py37", # only version available with 2.2.0
metric_definitions=[
{"Name": "train:global_steps", "Regex": r"global_step\/sec:\s(.*)"}
],
Expand Down
32 changes: 12 additions & 20 deletions tests/integ/test_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@
from sagemaker.utils import unique_name_from_base, sagemaker_timestamp

import tests.integ
from tests.integ import timeout
from tests.integ import kms_utils
from tests.integ import kms_utils, timeout, PYTHON_VERSION
from tests.integ.retry import retries
from tests.integ.s3_utils import assert_s3_files_exist

Expand All @@ -40,13 +39,8 @@
TAGS = [{"Key": "some-key", "Value": "some-value"}]


@pytest.fixture(scope="module")
def py_version(tf_full_version, tf_serving_version):
return "py37" if tf_full_version == tf_serving_version else tests.integ.PYTHON_VERSION


def test_mnist_with_checkpoint_config(
sagemaker_session, instance_type, tf_full_version, py_version
sagemaker_session, instance_type, tf_full_version, tf_full_py_version
):
checkpoint_s3_uri = "s3://{}/checkpoints/tf-{}".format(
sagemaker_session.default_bucket(), sagemaker_timestamp()
Expand All @@ -59,7 +53,7 @@ def test_mnist_with_checkpoint_config(
train_instance_type=instance_type,
sagemaker_session=sagemaker_session,
framework_version=tf_full_version,
py_version="py37",
py_version=tf_full_py_version,
metric_definitions=[{"Name": "train:global_steps", "Regex": r"global_step\/sec:\s(.*)"}],
checkpoint_s3_uri=checkpoint_s3_uri,
checkpoint_local_path=checkpoint_local_path,
Expand Down Expand Up @@ -89,7 +83,7 @@ def test_mnist_with_checkpoint_config(
assert actual_training_checkpoint_config == expected_training_checkpoint_config


def test_server_side_encryption(sagemaker_session, tf_serving_version, py_version):
def test_server_side_encryption(sagemaker_session, tf_serving_version):
with kms_utils.bucket_with_encryption(sagemaker_session, ROLE) as (bucket_with_kms, kms_key):
output_path = os.path.join(
bucket_with_kms, "test-server-side-encryption", time.strftime("%y%m%d-%H%M")
Expand All @@ -103,7 +97,7 @@ def test_server_side_encryption(sagemaker_session, tf_serving_version, py_versio
train_instance_type="ml.c5.xlarge",
sagemaker_session=sagemaker_session,
framework_version=tf_serving_version,
py_version=py_version,
py_version=PYTHON_VERSION,
code_location=output_path,
output_path=output_path,
model_dir="/opt/ml/model",
Expand All @@ -130,15 +124,15 @@ def test_server_side_encryption(sagemaker_session, tf_serving_version, py_versio


@pytest.mark.canary_quick
def test_mnist_distributed(sagemaker_session, instance_type, tf_full_version, py_version):
def test_mnist_distributed(sagemaker_session, instance_type, tf_full_version, tf_full_py_version):
estimator = TensorFlow(
entry_point=SCRIPT,
role=ROLE,
train_instance_count=2,
train_instance_type=instance_type,
sagemaker_session=sagemaker_session,
py_version="py37",
framework_version=tf_full_version,
py_version=tf_full_py_version,
distributions=PARAMETER_SERVER_DISTRIBUTION,
)
inputs = estimator.sagemaker_session.upload_data(
Expand All @@ -154,13 +148,13 @@ def test_mnist_distributed(sagemaker_session, instance_type, tf_full_version, py
)


def test_mnist_async(sagemaker_session, cpu_instance_type, tf_full_version, py_version):
def test_mnist_async(sagemaker_session, cpu_instance_type):
estimator = TensorFlow(
entry_point=SCRIPT,
role=ROLE,
train_instance_count=1,
train_instance_type="ml.c5.4xlarge",
py_version=tests.integ.PYTHON_VERSION,
py_version=PYTHON_VERSION,
sagemaker_session=sagemaker_session,
# testing py-sdk functionality, no need to run against all TF versions
framework_version=LATEST_SERVING_VERSION,
Expand Down Expand Up @@ -195,18 +189,16 @@ def test_mnist_async(sagemaker_session, cpu_instance_type, tf_full_version, py_v
_assert_model_name_match(sagemaker_session.sagemaker_client, endpoint_name, model_name)


def test_deploy_with_input_handlers(
sagemaker_session, instance_type, tf_serving_version, py_version
):
def test_deploy_with_input_handlers(sagemaker_session, instance_type, tf_serving_version):
estimator = TensorFlow(
entry_point="training.py",
source_dir=TFS_RESOURCE_PATH,
role=ROLE,
train_instance_count=1,
train_instance_type=instance_type,
py_version=py_version,
sagemaker_session=sagemaker_session,
framework_version=tf_serving_version,
py_version=PYTHON_VERSION,
sagemaker_session=sagemaker_session,
tags=TAGS,
)

Expand Down
6 changes: 3 additions & 3 deletions tests/integ/test_tf_efs_fsx.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
FSX_DIR_PATH = "/fsx/tensorflow"
MAX_JOBS = 2
MAX_PARALLEL_JOBS = 2
PY_VERSION = "py3"
PY_VERSION = "py37"


@pytest.fixture(scope="module")
Expand Down Expand Up @@ -139,8 +139,8 @@ def test_tuning_tf_efs(efs_fsx_setup, sagemaker_session, cpu_instance_type):
train_instance_count=1,
train_instance_type=cpu_instance_type,
sagemaker_session=sagemaker_session,
py_version=PY_VERSION,
framework_version=TensorFlow.LATEST_VERSION,
py_version=PY_VERSION,
subnets=subnets,
security_group_ids=security_group_ids,
)
Expand Down Expand Up @@ -186,8 +186,8 @@ def test_tuning_tf_lustre(efs_fsx_setup, sagemaker_session, cpu_instance_type):
train_instance_count=1,
train_instance_type=cpu_instance_type,
sagemaker_session=sagemaker_session,
py_version=PY_VERSION,
framework_version=TensorFlow.LATEST_VERSION,
py_version=PY_VERSION,
subnets=subnets,
security_group_ids=security_group_ids,
)
Expand Down
Loading

0 comments on commit dbdaf50

Please sign in to comment.