Skip to content

Commit

Permalink
change: add XGBoost support to image_uris.retrieve() (#1714)
Browse files Browse the repository at this point in the history
  • Loading branch information
laurenyu authored Jul 16, 2020
1 parent 211f4e5 commit 3a90f94
Show file tree
Hide file tree
Showing 9 changed files with 357 additions and 58 deletions.
122 changes: 122 additions & 0 deletions src/sagemaker/image_uri_config/xgboost.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
{
"scope": ["inference", "training"],
"version_aliases": {
"latest": "1"
},
"versions": {
"1": {
"registries": {
"ap-east-1": "286214385809",
"ap-northeast-1": "501404015308",
"ap-northeast-2": "306986355934",
"ap-south-1": "991648021394",
"ap-southeast-1": "475088953585",
"ap-southeast-2": "544295431143",
"ca-central-1": "469771592824",
"cn-north-1": "390948362332",
"cn-northwest-1": "387376663083",
"eu-central-1": "813361260812",
"eu-north-1": "669576153137",
"eu-west-1": "685385470294",
"eu-west-2": "644912444149",
"eu-west-3": "749696950732",
"me-south-1": "249704162688",
"sa-east-1": "855470959533",
"us-east-1": "811284229777",
"us-east-2": "825641698319",
"us-gov-west-1": "226302683700",
"us-iso-east-1": "490574956308",
"us-west-1": "632365934929",
"us-west-2": "433757028032"
},
"repository": "xgboost"
},
"0.90-1": {
"processors": ["cpu"],
"py_versions": ["py3"],
"registries": {
"ap-east-1": "651117190479",
"ap-northeast-1": "354813040037",
"ap-northeast-2": "366743142698",
"ap-south-1": "720646828776",
"ap-southeast-1": "121021644041",
"ap-southeast-2": "783357654285",
"ca-central-1": "341280168497",
"cn-north-1": "450853457545",
"cn-northwest-1": "451049120500",
"eu-central-1": "492215442770",
"eu-north-1": "662702820516",
"eu-west-1": "141502667606",
"eu-west-2": "764974769150",
"eu-west-3": "659782779980",
"me-south-1": "801668240914",
"sa-east-1": "737474898029",
"us-east-1": "683313688378",
"us-east-2": "257758044811",
"us-gov-west-1": "414596584902",
"us-iso-east-1": "833128469047",
"us-west-1": "746614075791",
"us-west-2": "246618743249"
},
"repository": "sagemaker-xgboost"
},
"0.90-2": {
"processors": ["cpu"],
"py_versions": ["py3"],
"registries": {
"ap-east-1": "651117190479",
"ap-northeast-1": "354813040037",
"ap-northeast-2": "366743142698",
"ap-south-1": "720646828776",
"ap-southeast-1": "121021644041",
"ap-southeast-2": "783357654285",
"ca-central-1": "341280168497",
"cn-north-1": "450853457545",
"cn-northwest-1": "451049120500",
"eu-central-1": "492215442770",
"eu-north-1": "662702820516",
"eu-west-1": "141502667606",
"eu-west-2": "764974769150",
"eu-west-3": "659782779980",
"me-south-1": "801668240914",
"sa-east-1": "737474898029",
"us-east-1": "683313688378",
"us-east-2": "257758044811",
"us-gov-west-1": "414596584902",
"us-iso-east-1": "833128469047",
"us-west-1": "746614075791",
"us-west-2": "246618743249"
},
"repository": "sagemaker-xgboost"
},
"1.0-1": {
"processors": ["cpu"],
"py_versions": ["py3"],
"registries": {
"ap-east-1": "651117190479",
"ap-northeast-1": "354813040037",
"ap-northeast-2": "366743142698",
"ap-south-1": "720646828776",
"ap-southeast-1": "121021644041",
"ap-southeast-2": "783357654285",
"ca-central-1": "341280168497",
"cn-north-1": "450853457545",
"cn-northwest-1": "451049120500",
"eu-central-1": "492215442770",
"eu-north-1": "662702820516",
"eu-west-1": "141502667606",
"eu-west-2": "764974769150",
"eu-west-3": "659782779980",
"me-south-1": "801668240914",
"sa-east-1": "737474898029",
"us-east-1": "683313688378",
"us-east-2": "257758044811",
"us-gov-west-1": "414596584902",
"us-iso-east-1": "833128469047",
"us-west-1": "746614075791",
"us-west-2": "246618743249"
},
"repository": "sagemaker-xgboost"
}
}
}
14 changes: 12 additions & 2 deletions src/sagemaker/image_uris.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,12 @@ def retrieve(
registry = _registry_from_region(region, version_config["registries"])
hostname = utils._botocore_resolver().construct_endpoint("ecr", region)["hostname"]

processor = _processor(
instance_type, config.get("processors") or version_config.get("processors")
)
tag = _format_tag(version, processor, py_version)

repo = version_config["repository"]
tag = _format_tag(version, _processor(instance_type, config.get("processors")), py_version)

return ECR_URI_TEMPLATE.format(registry=registry, hostname=hostname, repository=repo, tag=tag)

Expand Down Expand Up @@ -138,11 +142,17 @@ def _processor(instance_type, available_processors):
logger.info("Ignoring unnecessary instance type: %s.", instance_type)
return None

if not instance_type:
raise ValueError(
"Empty SageMaker instance type. For options, see: "
"https://aws.amazon.com/sagemaker/pricing/instance-types"
)

if instance_type.startswith("local"):
processor = "cpu" if instance_type == "local" else "gpu"
elif not instance_type.startswith("ml."):
raise ValueError(
"Invalid SageMaker instance type: {}. See: "
"Invalid SageMaker instance type: {}. For options, see: "
"https://aws.amazon.com/sagemaker/pricing/instance-types".format(instance_type)
)
else:
Expand Down
10 changes: 6 additions & 4 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,9 +159,11 @@ def sklearn_version(request):
return request.param


@pytest.fixture(scope="module", params=["0.90-1"])
def xgboost_version(request):
return request.param
@pytest.fixture(scope="module")
def xgboost_framework_version(xgboost_version):
if xgboost_version in ("1", "latest"):
pytest.skip("Skipping XGBoost algorithm version.")
return xgboost_version


@pytest.fixture(scope="module", params=["py2", "py3"])
Expand Down Expand Up @@ -351,7 +353,7 @@ def pytest_generate_tests(metafunc):


def _generate_all_framework_version_fixtures(metafunc):
for fw in ("chainer", "tensorflow"):
for fw in ("chainer", "tensorflow", "xgboost"):
config = image_uris.config_for_framework(fw)
if "scope" in config:
_parametrize_framework_version_fixtures(metafunc, fw, config)
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/sagemaker/image_uris/expected_uris.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,6 @@ def framework_uri(repo, fw_version, account, py_version=None, processor="cpu", r
return IMAGE_URI_FORMAT.format(account, region, domain, repo, tag)


def algo_uri(algo, account, region):
def algo_uri(algo, account, region, version=1):
domain = ALTERNATE_DOMAINS.get(region, DOMAIN)
return IMAGE_URI_FORMAT.format(account, region, domain, algo, 1)
return IMAGE_URI_FORMAT.format(account, region, domain, algo, version)
22 changes: 22 additions & 0 deletions tests/unit/sagemaker/image_uris/regions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import boto3


def regions():
boto_session = boto3.Session()
for partition in boto_session.get_available_partitions():
for region in boto_session.get_available_regions("sagemaker", partition_name=partition):
yield region
13 changes: 2 additions & 11 deletions tests/unit/sagemaker/image_uris/test_algos.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,8 @@
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import boto3

from sagemaker import image_uris
from tests.unit.sagemaker.image_uris import expected_uris
from tests.unit.sagemaker.image_uris import expected_uris, regions

ALGO_REGIONS_AND_ACCOUNTS = (
{
Expand Down Expand Up @@ -60,13 +58,6 @@
IMAGE_URI_FORMAT = "{}.dkr.ecr.{}.{}/{}:1"


def _regions():
boto_session = boto3.Session()
for partition in boto_session.get_available_partitions():
for region in boto_session.get_available_regions("sagemaker", partition_name=partition):
yield region


def _accounts_for_algo(algo):
for algo_account_dict in ALGO_REGIONS_AND_ACCOUNTS:
if algo in algo_account_dict["algorithms"]:
Expand All @@ -79,7 +70,7 @@ def test_factorization_machines():
algo = "factorization-machines"
accounts = _accounts_for_algo(algo)

for region in _regions():
for region in regions.regions():
for scope in ("training", "inference"):
uri = image_uris.retrieve(algo, region, image_scope=scope)
assert expected_uris.algo_uri(algo, accounts[region], region) == uri
39 changes: 39 additions & 0 deletions tests/unit/sagemaker/image_uris/test_retrieve.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,34 @@ def test_retrieve_processor_type(config_for_framework):
assert "123412341234.dkr.ecr.us-west-2.amazonaws.com/dummy:1.0.0-gpu-py3" == uri


@patch("sagemaker.image_uris.config_for_framework")
def test_retrieve_processor_type_from_version_specific_processor_config(config_for_framework):
config = copy.deepcopy(BASE_CONFIG)
del config["processors"]
config["versions"]["1.0.0"]["processors"] = ["cpu"]
config_for_framework.return_value = config

uri = image_uris.retrieve(
framework="useless-string",
version="1.0.0",
py_version="py3",
instance_type="ml.c4.xlarge",
region="us-west-2",
image_scope="training",
)
assert "123412341234.dkr.ecr.us-west-2.amazonaws.com/dummy:1.0.0-cpu-py3" == uri

uri = image_uris.retrieve(
framework="useless-string",
version="1.1.0",
py_version="py3",
instance_type="ml.c4.xlarge",
region="us-west-2",
image_scope="training",
)
assert "123412341234.dkr.ecr.us-west-2.amazonaws.com/dummy:1.1.0-py3" == uri


@patch("sagemaker.image_uris.config_for_framework", return_value=BASE_CONFIG)
def test_retrieve_unsupported_processor_type(config_for_framework):
with pytest.raises(ValueError) as e:
Expand All @@ -388,6 +416,17 @@ def test_retrieve_unsupported_processor_type(config_for_framework):

assert "Invalid SageMaker instance type: not-an-instance-type." in str(e.value)

with pytest.raises(ValueError) as e:
image_uris.retrieve(
framework="useless-string",
version="1.0.0",
py_version="py3",
region="us-west-2",
image_scope="training",
)

assert "Empty SageMaker instance type." in str(e.value)

config = copy.deepcopy(BASE_CONFIG)
config["processors"] = ["cpu"]
config_for_framework.return_value = config
Expand Down
Loading

0 comments on commit 3a90f94

Please sign in to comment.