From bd28ca3cb9f8e343780e8bb18792adb34bbbc446 Mon Sep 17 00:00:00 2001 From: Lauren Yu <6631887+laurenyu@users.noreply.github.com> Date: Thu, 16 Jul 2020 13:14:51 -0700 Subject: [PATCH] change: add MXNet configuration to image_uris.retrieve() (#1716) --- src/sagemaker/image_uri_config/mxnet.json | 656 ++++++++++++++++++ tests/conftest.py | 2 +- .../image_uris/test_dlc_frameworks.py | 305 ++++---- 3 files changed, 843 insertions(+), 120 deletions(-) create mode 100644 src/sagemaker/image_uri_config/mxnet.json diff --git a/src/sagemaker/image_uri_config/mxnet.json b/src/sagemaker/image_uri_config/mxnet.json new file mode 100644 index 0000000000..b878a70c31 --- /dev/null +++ b/src/sagemaker/image_uri_config/mxnet.json @@ -0,0 +1,656 @@ +{ + "training": { + "processors": ["cpu", "gpu"], + "version_aliases": { + "0.12": "0.12.1", + "1.0": "1.0.0", + "1.1": "1.1.0", + "1.2": "1.2.1", + "1.3": "1.3.0", + "1.4": "1.4.1", + "1.6": "1.6.0" + }, + "versions": { + "0.12.1": { + "registries": { + "ap-east-1": "057415533634", + "ap-northeast-1": "520713654638", + "ap-northeast-2": "520713654638", + "ap-south-1": "520713654638", + "ap-southeast-1": "520713654638", + "ap-southeast-2": "520713654638", + "ca-central-1": "520713654638", + "cn-north-1": "422961961927", + "cn-northwest-1": "423003514399", + "eu-central-1": "520713654638", + "eu-north-1": "520713654638", + "eu-west-1": "520713654638", + "eu-west-2": "520713654638", + "eu-west-3": "520713654638", + "me-south-1": "724002660598", + "sa-east-1": "520713654638", + "us-east-1": "520713654638", + "us-east-2": "520713654638", + "us-gov-west-1": "246785580436", + "us-iso-east-1": "744548109606", + "us-west-1": "520713654638", + "us-west-2": "520713654638" + }, + "repository": "sagemaker-mxnet", + "py_versions": ["py2", "py3"] + }, + "1.0.0": { + "registries": { + "ap-east-1": "057415533634", + "ap-northeast-1": "520713654638", + "ap-northeast-2": "520713654638", + "ap-south-1": "520713654638", + "ap-southeast-1": "520713654638", + "ap-southeast-2": "520713654638", + "ca-central-1": "520713654638", + "cn-north-1": "422961961927", + "cn-northwest-1": "423003514399", + "eu-central-1": "520713654638", + "eu-north-1": "520713654638", + "eu-west-1": "520713654638", + "eu-west-2": "520713654638", + "eu-west-3": "520713654638", + "me-south-1": "724002660598", + "sa-east-1": "520713654638", + "us-east-1": "520713654638", + "us-east-2": "520713654638", + "us-gov-west-1": "246785580436", + "us-iso-east-1": "744548109606", + "us-west-1": "520713654638", + "us-west-2": "520713654638" + }, + "repository": "sagemaker-mxnet", + "py_versions": ["py2", "py3"] + }, + "1.1.0": { + "registries": { + "ap-east-1": "057415533634", + "ap-northeast-1": "520713654638", + "ap-northeast-2": "520713654638", + "ap-south-1": "520713654638", + "ap-southeast-1": "520713654638", + "ap-southeast-2": "520713654638", + "ca-central-1": "520713654638", + "cn-north-1": "422961961927", + "cn-northwest-1": "423003514399", + "eu-central-1": "520713654638", + "eu-north-1": "520713654638", + "eu-west-1": "520713654638", + "eu-west-2": "520713654638", + "eu-west-3": "520713654638", + "me-south-1": "724002660598", + "sa-east-1": "520713654638", + "us-east-1": "520713654638", + "us-east-2": "520713654638", + "us-gov-west-1": "246785580436", + "us-iso-east-1": "744548109606", + "us-west-1": "520713654638", + "us-west-2": "520713654638" + }, + "repository": "sagemaker-mxnet", + "py_versions": ["py2", "py3"] + }, + "1.2.1": { + "registries": { + "ap-east-1": "057415533634", + "ap-northeast-1": "520713654638", + "ap-northeast-2": "520713654638", + "ap-south-1": "520713654638", + "ap-southeast-1": "520713654638", + "ap-southeast-2": "520713654638", + "ca-central-1": "520713654638", + "cn-north-1": "422961961927", + "cn-northwest-1": "423003514399", + "eu-central-1": "520713654638", + "eu-north-1": "520713654638", + "eu-west-1": "520713654638", + "eu-west-2": "520713654638", + "eu-west-3": "520713654638", + "me-south-1": "724002660598", + "sa-east-1": "520713654638", + "us-east-1": "520713654638", + "us-east-2": "520713654638", + "us-gov-west-1": "246785580436", + "us-iso-east-1": "744548109606", + "us-west-1": "520713654638", + "us-west-2": "520713654638" + }, + "repository": "sagemaker-mxnet", + "py_versions": ["py2", "py3"] + }, + "1.3.0": { + "registries": { + "ap-east-1": "057415533634", + "ap-northeast-1": "520713654638", + "ap-northeast-2": "520713654638", + "ap-south-1": "520713654638", + "ap-southeast-1": "520713654638", + "ap-southeast-2": "520713654638", + "ca-central-1": "520713654638", + "cn-north-1": "422961961927", + "cn-northwest-1": "423003514399", + "eu-central-1": "520713654638", + "eu-north-1": "520713654638", + "eu-west-1": "520713654638", + "eu-west-2": "520713654638", + "eu-west-3": "520713654638", + "me-south-1": "724002660598", + "sa-east-1": "520713654638", + "us-east-1": "520713654638", + "us-east-2": "520713654638", + "us-gov-west-1": "246785580436", + "us-iso-east-1": "744548109606", + "us-west-1": "520713654638", + "us-west-2": "520713654638" + }, + "repository": "sagemaker-mxnet", + "py_versions": ["py2", "py3"] + }, + "1.4.0": { + "registries": { + "ap-east-1": "057415533634", + "ap-northeast-1": "520713654638", + "ap-northeast-2": "520713654638", + "ap-south-1": "520713654638", + "ap-southeast-1": "520713654638", + "ap-southeast-2": "520713654638", + "ca-central-1": "520713654638", + "cn-north-1": "422961961927", + "cn-northwest-1": "423003514399", + "eu-central-1": "520713654638", + "eu-north-1": "520713654638", + "eu-west-1": "520713654638", + "eu-west-2": "520713654638", + "eu-west-3": "520713654638", + "me-south-1": "724002660598", + "sa-east-1": "520713654638", + "us-east-1": "520713654638", + "us-east-2": "520713654638", + "us-gov-west-1": "246785580436", + "us-iso-east-1": "744548109606", + "us-west-1": "520713654638", + "us-west-2": "520713654638" + }, + "repository": "sagemaker-mxnet", + "py_versions": ["py2", "py3"] + }, + "1.4.1": { + "py2": { + "registries": { + "ap-east-1": "057415533634", + "ap-northeast-1": "520713654638", + "ap-northeast-2": "520713654638", + "ap-south-1": "520713654638", + "ap-southeast-1": "520713654638", + "ap-southeast-2": "520713654638", + "ca-central-1": "520713654638", + "cn-north-1": "422961961927", + "cn-northwest-1": "423003514399", + "eu-central-1": "520713654638", + "eu-north-1": "520713654638", + "eu-west-1": "520713654638", + "eu-west-2": "520713654638", + "eu-west-3": "520713654638", + "me-south-1": "724002660598", + "sa-east-1": "520713654638", + "us-east-1": "520713654638", + "us-east-2": "520713654638", + "us-gov-west-1": "246785580436", + "us-iso-east-1": "744548109606", + "us-west-1": "520713654638", + "us-west-2": "520713654638" + }, + "repository": "sagemaker-mxnet" + }, + "py3": { + "registries": { + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-south-1": "763104351884", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "repository": "mxnet-training" + } + }, + "1.6.0": { + "registries": { + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-south-1": "763104351884", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "repository": "mxnet-training", + "py_versions": ["py2", "py3"] + } + } + }, + "inference": { + "processors": ["cpu", "gpu"], + "version_aliases": { + "0.12": "0.12.1", + "1.0": "1.0.0", + "1.1": "1.1.0", + "1.2": "1.2.1", + "1.3": "1.3.0", + "1.4": "1.4.1", + "1.6": "1.6.0" + }, + "versions": { + "0.12.1": { + "registries": { + "ap-east-1": "057415533634", + "ap-northeast-1": "520713654638", + "ap-northeast-2": "520713654638", + "ap-south-1": "520713654638", + "ap-southeast-1": "520713654638", + "ap-southeast-2": "520713654638", + "ca-central-1": "520713654638", + "cn-north-1": "422961961927", + "cn-northwest-1": "423003514399", + "eu-central-1": "520713654638", + "eu-north-1": "520713654638", + "eu-west-1": "520713654638", + "eu-west-2": "520713654638", + "eu-west-3": "520713654638", + "me-south-1": "724002660598", + "sa-east-1": "520713654638", + "us-east-1": "520713654638", + "us-east-2": "520713654638", + "us-gov-west-1": "246785580436", + "us-iso-east-1": "744548109606", + "us-west-1": "520713654638", + "us-west-2": "520713654638" + }, + "repository": "sagemaker-mxnet", + "py_versions": ["py2", "py3"] + }, + "1.0.0": { + "registries": { + "ap-east-1": "057415533634", + "ap-northeast-1": "520713654638", + "ap-northeast-2": "520713654638", + "ap-south-1": "520713654638", + "ap-southeast-1": "520713654638", + "ap-southeast-2": "520713654638", + "ca-central-1": "520713654638", + "cn-north-1": "422961961927", + "cn-northwest-1": "423003514399", + "eu-central-1": "520713654638", + "eu-north-1": "520713654638", + "eu-west-1": "520713654638", + "eu-west-2": "520713654638", + "eu-west-3": "520713654638", + "me-south-1": "724002660598", + "sa-east-1": "520713654638", + "us-east-1": "520713654638", + "us-east-2": "520713654638", + "us-gov-west-1": "246785580436", + "us-iso-east-1": "744548109606", + "us-west-1": "520713654638", + "us-west-2": "520713654638" + }, + "repository": "sagemaker-mxnet", + "py_versions": ["py2", "py3"] + }, + "1.1.0": { + "registries": { + "ap-east-1": "057415533634", + "ap-northeast-1": "520713654638", + "ap-northeast-2": "520713654638", + "ap-south-1": "520713654638", + "ap-southeast-1": "520713654638", + "ap-southeast-2": "520713654638", + "ca-central-1": "520713654638", + "cn-north-1": "422961961927", + "cn-northwest-1": "423003514399", + "eu-central-1": "520713654638", + "eu-north-1": "520713654638", + "eu-west-1": "520713654638", + "eu-west-2": "520713654638", + "eu-west-3": "520713654638", + "me-south-1": "724002660598", + "sa-east-1": "520713654638", + "us-east-1": "520713654638", + "us-east-2": "520713654638", + "us-gov-west-1": "246785580436", + "us-iso-east-1": "744548109606", + "us-west-1": "520713654638", + "us-west-2": "520713654638" + }, + "repository": "sagemaker-mxnet", + "py_versions": ["py2", "py3"] + }, + "1.2.1": { + "registries": { + "ap-east-1": "057415533634", + "ap-northeast-1": "520713654638", + "ap-northeast-2": "520713654638", + "ap-south-1": "520713654638", + "ap-southeast-1": "520713654638", + "ap-southeast-2": "520713654638", + "ca-central-1": "520713654638", + "cn-north-1": "422961961927", + "cn-northwest-1": "423003514399", + "eu-central-1": "520713654638", + "eu-north-1": "520713654638", + "eu-west-1": "520713654638", + "eu-west-2": "520713654638", + "eu-west-3": "520713654638", + "me-south-1": "724002660598", + "sa-east-1": "520713654638", + "us-east-1": "520713654638", + "us-east-2": "520713654638", + "us-gov-west-1": "246785580436", + "us-iso-east-1": "744548109606", + "us-west-1": "520713654638", + "us-west-2": "520713654638" + }, + "repository": "sagemaker-mxnet", + "py_versions": ["py2", "py3"] + }, + "1.3.0": { + "registries": { + "ap-east-1": "057415533634", + "ap-northeast-1": "520713654638", + "ap-northeast-2": "520713654638", + "ap-south-1": "520713654638", + "ap-southeast-1": "520713654638", + "ap-southeast-2": "520713654638", + "ca-central-1": "520713654638", + "cn-north-1": "422961961927", + "cn-northwest-1": "423003514399", + "eu-central-1": "520713654638", + "eu-north-1": "520713654638", + "eu-west-1": "520713654638", + "eu-west-2": "520713654638", + "eu-west-3": "520713654638", + "me-south-1": "724002660598", + "sa-east-1": "520713654638", + "us-east-1": "520713654638", + "us-east-2": "520713654638", + "us-gov-west-1": "246785580436", + "us-iso-east-1": "744548109606", + "us-west-1": "520713654638", + "us-west-2": "520713654638" + }, + "repository": "sagemaker-mxnet", + "py_versions": ["py2", "py3"] + }, + "1.4.0": { + "registries": { + "ap-east-1": "057415533634", + "ap-northeast-1": "520713654638", + "ap-northeast-2": "520713654638", + "ap-south-1": "520713654638", + "ap-southeast-1": "520713654638", + "ap-southeast-2": "520713654638", + "ca-central-1": "520713654638", + "cn-north-1": "422961961927", + "cn-northwest-1": "423003514399", + "eu-central-1": "520713654638", + "eu-north-1": "520713654638", + "eu-west-1": "520713654638", + "eu-west-2": "520713654638", + "eu-west-3": "520713654638", + "me-south-1": "724002660598", + "sa-east-1": "520713654638", + "us-east-1": "520713654638", + "us-east-2": "520713654638", + "us-gov-west-1": "246785580436", + "us-iso-east-1": "744548109606", + "us-west-1": "520713654638", + "us-west-2": "520713654638" + }, + "repository": "sagemaker-mxnet-serving", + "py_versions": ["py2", "py3"] + }, + "1.4.1": { + "py2": { + "registries": { + "ap-east-1": "057415533634", + "ap-northeast-1": "520713654638", + "ap-northeast-2": "520713654638", + "ap-south-1": "520713654638", + "ap-southeast-1": "520713654638", + "ap-southeast-2": "520713654638", + "ca-central-1": "520713654638", + "cn-north-1": "422961961927", + "cn-northwest-1": "423003514399", + "eu-central-1": "520713654638", + "eu-north-1": "520713654638", + "eu-west-1": "520713654638", + "eu-west-2": "520713654638", + "eu-west-3": "520713654638", + "me-south-1": "724002660598", + "sa-east-1": "520713654638", + "us-east-1": "520713654638", + "us-east-2": "520713654638", + "us-gov-west-1": "246785580436", + "us-iso-east-1": "744548109606", + "us-west-1": "520713654638", + "us-west-2": "520713654638" + }, + "repository": "sagemaker-mxnet-serving" + }, + "py3": { + "registries": { + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-south-1": "763104351884", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "repository": "mxnet-inference" + } + }, + "1.6.0": { + "registries": { + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-south-1": "763104351884", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "repository": "mxnet-inference", + "py_versions": ["py2", "py3"] + } + } + }, + "eia": { + "processors": ["cpu"], + "version_aliases": { + "1.3": "1.3.0", + "1.4": "1.4.1", + "1.5": "1.5.1" + }, + "versions": { + "1.3.0": { + "registries": { + "ap-east-1": "057415533634", + "ap-northeast-1": "520713654638", + "ap-northeast-2": "520713654638", + "ap-south-1": "520713654638", + "ap-southeast-1": "520713654638", + "ap-southeast-2": "520713654638", + "ca-central-1": "520713654638", + "cn-north-1": "422961961927", + "cn-northwest-1": "423003514399", + "eu-central-1": "520713654638", + "eu-north-1": "520713654638", + "eu-west-1": "520713654638", + "eu-west-2": "520713654638", + "eu-west-3": "520713654638", + "me-south-1": "724002660598", + "sa-east-1": "520713654638", + "us-east-1": "520713654638", + "us-east-2": "520713654638", + "us-gov-west-1": "246785580436", + "us-iso-east-1": "744548109606", + "us-west-1": "520713654638", + "us-west-2": "520713654638" + }, + "repository": "sagemaker-mxnet-eia", + "py_versions": ["py2", "py3"] + }, + "1.4.0": { + "registries": { + "ap-east-1": "057415533634", + "ap-northeast-1": "520713654638", + "ap-northeast-2": "520713654638", + "ap-south-1": "520713654638", + "ap-southeast-1": "520713654638", + "ap-southeast-2": "520713654638", + "ca-central-1": "520713654638", + "cn-north-1": "422961961927", + "cn-northwest-1": "423003514399", + "eu-central-1": "520713654638", + "eu-north-1": "520713654638", + "eu-west-1": "520713654638", + "eu-west-2": "520713654638", + "eu-west-3": "520713654638", + "me-south-1": "724002660598", + "sa-east-1": "520713654638", + "us-east-1": "520713654638", + "us-east-2": "520713654638", + "us-gov-west-1": "246785580436", + "us-iso-east-1": "744548109606", + "us-west-1": "520713654638", + "us-west-2": "520713654638" + }, + "repository": "sagemaker-mxnet-serving-eia", + "py_versions": ["py2", "py3"] + }, + "1.4.1": { + "registries": { + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-south-1": "763104351884", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "repository": "mxnet-inference-eia", + "py_versions": ["py2", "py3"] + }, + "1.5.1": { + "registries": { + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-south-1": "763104351884", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "repository": "mxnet-inference-eia", + "py_versions": ["py2", "py3"] + } + } + } +} diff --git a/tests/conftest.py b/tests/conftest.py index 43f34c67dd..97faa9c1d6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -353,7 +353,7 @@ def pytest_generate_tests(metafunc): def _generate_all_framework_version_fixtures(metafunc): - for fw in ("chainer", "tensorflow", "xgboost"): + for fw in ("chainer", "mxnet", "tensorflow", "xgboost"): config = image_uris.config_for_framework(fw) if "scope" in config: _parametrize_framework_version_fixtures(metafunc, fw, config) diff --git a/tests/unit/sagemaker/image_uris/test_dlc_frameworks.py b/tests/unit/sagemaker/image_uris/test_dlc_frameworks.py index 25dd636674..be5b5f4e8c 100644 --- a/tests/unit/sagemaker/image_uris/test_dlc_frameworks.py +++ b/tests/unit/sagemaker/image_uris/test_dlc_frameworks.py @@ -40,77 +40,71 @@ } -def test_chainer(chainer_version, chainer_py_version): - for instance_type, processor in INSTANCE_TYPES_AND_PROCESSORS: - for scope in ("training", "inference"): - uri = image_uris.retrieve( - framework="chainer", - region=REGION, - version=chainer_version, - py_version=chainer_py_version, - instance_type=instance_type, - image_scope=scope, - ) - expected = expected_uris.framework_uri( - repo="sagemaker-chainer", - fw_version=chainer_version, - py_version=chainer_py_version, - account=SAGEMAKER_ACCOUNT, - processor=processor, - ) - assert expected == uri - - for region, account in SAGEMAKER_ALTERNATE_REGION_ACCOUNTS.items(): - uri = image_uris.retrieve( - framework="chainer", - region=region, - version=chainer_version, - py_version=chainer_py_version, - instance_type="ml.c4.xlarge", - image_scope="training", - ) - expected = expected_uris.framework_uri( - repo="sagemaker-chainer", - fw_version=chainer_version, - py_version=chainer_py_version, - account=account, - region=region, - ) - assert expected == uri - +def _test_image_uris(framework, fw_version, py_version, scope, expected_fn, expected_fn_args): + base_args = { + "framework": framework, + "version": fw_version, + "py_version": py_version, + "image_scope": scope, + } -def test_tensorflow_training(tensorflow_training_version, tensorflow_training_py_version): for instance_type, processor in INSTANCE_TYPES_AND_PROCESSORS: - uri = image_uris.retrieve( - framework="tensorflow", - region=REGION, - version=tensorflow_training_version, - py_version=tensorflow_training_py_version, - instance_type=instance_type, - image_scope="training", - ) + uri = image_uris.retrieve(region=REGION, instance_type=instance_type, **base_args) - expected = _expected_tf_training_uri( - tensorflow_training_version, tensorflow_training_py_version, processor=processor - ) + expected = expected_fn(processor=processor, **expected_fn_args) assert expected == uri for region in SAGEMAKER_ALTERNATE_REGION_ACCOUNTS.keys(): - uri = image_uris.retrieve( - framework="tensorflow", - region=region, - version=tensorflow_training_version, - py_version=tensorflow_training_py_version, - instance_type="ml.c4.xlarge", - image_scope="training", - ) + uri = image_uris.retrieve(region=region, instance_type="ml.c4.xlarge", **base_args) - expected = _expected_tf_training_uri( - tensorflow_training_version, tensorflow_training_py_version, region=region - ) + expected = expected_fn(region=region, **expected_fn_args) assert expected == uri +def test_chainer(chainer_version, chainer_py_version): + expected_fn_args = { + "chainer_version": chainer_version, + "py_version": chainer_py_version, + } + + _test_image_uris( + "chainer", + chainer_version, + chainer_py_version, + "training", + _expected_chainer_uri, + expected_fn_args, + ) + + +def _expected_chainer_uri(chainer_version, py_version, processor="cpu", region=REGION): + account = SAGEMAKER_ACCOUNT if region == REGION else SAGEMAKER_ALTERNATE_REGION_ACCOUNTS[region] + return expected_uris.framework_uri( + repo="sagemaker-chainer", + fw_version=chainer_version, + py_version=py_version, + processor=processor, + region=region, + account=account, + ) + + +def test_tensorflow_training(tensorflow_training_version, tensorflow_training_py_version): + expected_fn_args = { + "tf_training_version": tensorflow_training_version, + "py_version": tensorflow_training_py_version, + } + + _test_image_uris( + "tensorflow", + tensorflow_training_version, + tensorflow_training_py_version, + "training", + _expected_tf_training_uri, + expected_fn_args, + ) + + def _expected_tf_training_uri(tf_training_version, py_version, processor="cpu", region=REGION): version = Version(tf_training_version) if version < Version("1.11"): @@ -122,17 +116,10 @@ def _expected_tf_training_uri(tf_training_version, py_version, processor="cpu", else: repo = "sagemaker-tensorflow-scriptmode" if py_version == "py2" else "tensorflow-training" - if repo.startswith("sagemaker"): - account = ( - SAGEMAKER_ACCOUNT if region == REGION else SAGEMAKER_ALTERNATE_REGION_ACCOUNTS[region] - ) - else: - account = DLC_ACCOUNT if region == REGION else DLC_ALTERNATE_REGION_ACCOUNTS[region] - return expected_uris.framework_uri( repo, tf_training_version, - account, + _sagemaker_or_dlc_account(repo, region), py_version=py_version, processor=processor, region=region, @@ -140,57 +127,33 @@ def _expected_tf_training_uri(tf_training_version, py_version, processor="cpu", def test_tensorflow_inference(tensorflow_inference_version): - for instance_type, processor in INSTANCE_TYPES_AND_PROCESSORS: - uri = image_uris.retrieve( - framework="tensorflow", - region=REGION, - version=tensorflow_inference_version, - py_version="py2", - instance_type=instance_type, - image_scope="inference", - ) - - expected = _expected_tf_inference_uri(tensorflow_inference_version, processor=processor) - assert expected == uri - - for region in SAGEMAKER_ALTERNATE_REGION_ACCOUNTS.keys(): - uri = image_uris.retrieve( - framework="tensorflow", - region=region, - version=tensorflow_inference_version, - py_version="py2", - instance_type="ml.c4.xlarge", - image_scope="inference", - ) - - expected = _expected_tf_inference_uri(tensorflow_inference_version, region=region) - assert expected == uri + _test_image_uris( + "tensorflow", + tensorflow_inference_version, + "py2", + "inference", + _expected_tf_inference_uri, + {"tf_inference_version": tensorflow_inference_version}, + ) def test_tensorflow_eia(tensorflow_eia_version): - uri = image_uris.retrieve( - framework="tensorflow", - region=REGION, - version=tensorflow_eia_version, - py_version="py2", - instance_type="ml.c4.xlarge", - accelerator_type="ml.eia1.medium", - image_scope="inference", - ) + base_args = { + "framework": "tensorflow", + "version": tensorflow_eia_version, + "py_version": "py2", + "instance_type": "ml.c4.xlarge", + "accelerator_type": "ml.eia1.medium", + "image_scope": "inference", + } + + uri = image_uris.retrieve(region=REGION, **base_args) expected = _expected_tf_inference_uri(tensorflow_eia_version, eia=True) assert expected == uri for region in SAGEMAKER_ALTERNATE_REGION_ACCOUNTS.keys(): - uri = image_uris.retrieve( - framework="tensorflow", - region=region, - version=tensorflow_eia_version, - py_version="py2", - instance_type="ml.c4.xlarge", - accelerator_type="ml.eia1.medium", - image_scope="inference", - ) + uri = image_uris.retrieve(region=region, **base_args) expected = _expected_tf_inference_uri(tensorflow_eia_version, region=region, eia=True) assert expected == uri @@ -201,13 +164,7 @@ def _expected_tf_inference_uri(tf_inference_version, processor="cpu", region=REG repo = _expected_tf_inference_repo(version, eia) py_version = "py2" if version < Version("1.11") else None - if repo.startswith("sagemaker"): - account = ( - SAGEMAKER_ACCOUNT if region == REGION else SAGEMAKER_ALTERNATE_REGION_ACCOUNTS[region] - ) - else: - account = DLC_ACCOUNT if region == REGION else DLC_ALTERNATE_REGION_ACCOUNTS[region] - + account = _sagemaker_or_dlc_account(repo, region) return expected_uris.framework_uri( repo, tf_inference_version, account, py_version, processor=processor, region=region, ) @@ -225,3 +182,113 @@ def _expected_tf_inference_repo(version, eia): repo = "-".join((repo, "eia")) return repo + + +def test_mxnet_training(mxnet_training_version, mxnet_py_version): + expected_fn_args = { + "mxnet_version": mxnet_training_version, + "py_version": mxnet_py_version, + } + + _test_image_uris( + "mxnet", + mxnet_training_version, + mxnet_py_version, + "training", + _expected_mxnet_training_uri, + expected_fn_args, + ) + + +def _expected_mxnet_training_uri(mxnet_version, py_version, processor="cpu", region=REGION): + version = Version(mxnet_version) + if version < Version("1.4") or mxnet_version == "1.4.0": + repo = "sagemaker-mxnet" + elif version >= Version("1.6.0"): + repo = "mxnet-training" + else: + repo = "sagemaker-mxnet" if py_version == "py2" else "mxnet-training" + + return expected_uris.framework_uri( + repo, + mxnet_version, + _sagemaker_or_dlc_account(repo, region), + py_version=py_version, + processor=processor, + region=region, + ) + + +def test_mxnet_inference(mxnet_inference_version, mxnet_py_version): + expected_fn_args = { + "mxnet_version": mxnet_inference_version, + "py_version": mxnet_py_version, + } + + _test_image_uris( + "mxnet", + mxnet_inference_version, + mxnet_py_version, + "inference", + _expected_mxnet_inference_uri, + expected_fn_args, + ) + + +def test_mxnet_eia(mxnet_eia_version, mxnet_py_version): + base_args = { + "framework": "mxnet", + "version": mxnet_eia_version, + "py_version": mxnet_py_version, + "image_scope": "inference", + "instance_type": "ml.c4.xlarge", + "accelerator_type": "ml.eia1.medium", + } + + uri = image_uris.retrieve(region=REGION, **base_args) + + expected = _expected_mxnet_inference_uri(mxnet_eia_version, mxnet_py_version, eia=True) + assert expected == uri + + for region in SAGEMAKER_ALTERNATE_REGION_ACCOUNTS.keys(): + uri = image_uris.retrieve(region=region, **base_args) + + expected = _expected_mxnet_inference_uri( + mxnet_eia_version, mxnet_py_version, region=region, eia=True + ) + assert expected == uri + + +def _expected_mxnet_inference_uri( + mxnet_version, py_version, processor="cpu", region=REGION, eia=False +): + version = Version(mxnet_version) + if version < Version("1.4"): + repo = "sagemaker-mxnet" + elif mxnet_version == "1.4.0": + repo = "sagemaker-mxnet-serving" + elif version >= Version("1.5"): + repo = "mxnet-inference" + else: + repo = "sagemaker-mxnet-serving" if py_version == "py2" and not eia else "mxnet-inference" + + if eia: + repo = "-".join((repo, "eia")) + + return expected_uris.framework_uri( + repo, + mxnet_version, + _sagemaker_or_dlc_account(repo, region), + py_version=py_version, + processor=processor, + region=region, + ) + + +def _sagemaker_or_dlc_account(repo, region): + if repo.startswith("sagemaker"): + return ( + SAGEMAKER_ACCOUNT if region == REGION else SAGEMAKER_ALTERNATE_REGION_ACCOUNTS[region] + ) + else: + return DLC_ACCOUNT if region == REGION else DLC_ALTERNATE_REGION_ACCOUNTS[region]