Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add mxnet-gluon-cifar10 example #2

Merged
merged 2 commits into from
Nov 6, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions im-python-sdk/mxnet_gluon_cifar10/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
data/
280 changes: 280 additions & 0 deletions im-python-sdk/mxnet_gluon_cifar10/cifar10.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,280 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Distributed ResNet Training with MXNet and Gluon\n",
"\n",
"[ResNet_V2](https://arxiv.org/abs/1512.03385) is an architecture for deep convolution networks. In this example, we train a 34 layer network to perform image classification using the CIFAR-10 dataset. CIFAR-10 consists of 60000 32x32 colour images in 10 classes, with 6000 images per class. There are 50000 training images and 10000 test images. \n",
"\n",
"### Setup\n",
"\n",
"This example requires the `scikit-image` library. Use jupyter's [conda tab](/tree#conda) to install it."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import credentials # put your credentials in credentials.py \n",
"import logging \n",
"import os\n",
"import im\n",
"from im.mxnet import MXNet\n",
"from im.session import s3_input \n",
"from mxnet import gluon\n",
"\n",
"os.environ['AWS_DEFAULT_REGION']='us-west-2'\n",
"\n",
"# Session will use your default credentials\n",
"# e.g. from environment variables or your ~/.aws/credentials file.\n",
"ims = im.Session()\n",
"\n",
"# Replace with a role that gives IM access to s3 and cloudwatch\n",
"# see 1-Creating_a_role_allowing_IM_to_access_S3_Cloudwatch_ECR.ipynb\n",
"role='IMRole'"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Download training and test data\n",
"\n",
"We use the helper scripts to download CIFAR10 training data and sample images."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from cifar10_utils import download_training_data\n",
"download_training_data()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Uploading the data\n",
"\n",
"We use the `im.Session.upload_data` function to upload our datasets to an S3 location. The return value `inputs` identifies the location -- we will use this later when we start the training job."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"inputs = ims.upload_data(path='data', key_prefix='data/gluon-cifar10')\n",
"print('input spec (in this case, just an S3 path): {}'.format(inputs))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Implement the training function\n",
"\n",
"We need to provide a training script that can run on the IM platform. The training scripts are essentially the same as one you would write for local training, except that you need to provide a `train` function. When IM calls your function, it will pass in arguments that describe the training environment. Check the script below to see how this works.\n",
"\n",
"The network itself is a pre-built version contained in the [Gluon Model Zoo](https://mxnet.incubator.apache.org/versions/master/api/python/gluon/model_zoo.html)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"!cat 'cifar10.py'"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Run the training script on IM\n",
"\n",
"The ```MXNet``` class allows us to run our training function as a distributed training job on IM infrastructure. We need to configure it with our training script, an IAM role, the number of training instances, and the training instance type. In this case we will run our training job on four p2.xlarge instances. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"m = MXNet(\"cifar10.py\", \n",
" role=role, \n",
" train_instance_count=4, \n",
" train_instance_type=\"p2.xlarge\",\n",
" hyperparameters={'batch_size': 128, \n",
" 'epochs': 50, \n",
" 'learning_rate': 0.1, \n",
" 'momentum': 0.9,\n",
" '_ps_verbose': 0})"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"After we've constructed our `MXNet` object, we can fit it using the data we uploaded to S3. IM makes sure our data is available in the local filesystem, so our training script can simply read the data from disk."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"m.fit(inputs)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Prediction\n",
"\n",
"After training, we use the MXNet object to create and deploy a hosted prediction endpoint. We can use the object returned by `deploy` to call the endpoint and perform inference on our sample image."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"predictor = m.deploy(min_instances=1, max_instances=1, instance_type='c4.xlarge')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### CIFAR10 sample images\n",
"\n",
"We'll use these CIFAR10 sample images to test the service:\n",
"\n",
"<img style=\"display: inline; height: 32px; margin: 0.25em\" src=\"images/airplane1.png\" />\n",
"<img style=\"display: inline; height: 32px; margin: 0.25em\" src=\"images/automobile1.png\" />\n",
"<img style=\"display: inline; height: 32px; margin: 0.25em\" src=\"images/bird1.png\" />\n",
"<img style=\"display: inline; height: 32px; margin: 0.25em\" src=\"images/cat1.png\" />\n",
"<img style=\"display: inline; height: 32px; margin: 0.25em\" src=\"images/deer1.png\" />\n",
"<img style=\"display: inline; height: 32px; margin: 0.25em\" src=\"images/dog1.png\" />\n",
"<img style=\"display: inline; height: 32px; margin: 0.25em\" src=\"images/frog1.png\" />\n",
"<img style=\"display: inline; height: 32px; margin: 0.25em\" src=\"images/horse1.png\" />\n",
"<img style=\"display: inline; height: 32px; margin: 0.25em\" src=\"images/ship1.png\" />\n",
"<img style=\"display: inline; height: 32px; margin: 0.25em\" src=\"images/truck1.png\" />\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# load the CIFAR10 samples, and convert them into format we can use with the prediction endpoint\n",
"from cifar10_utils import read_images\n",
"\n",
"filenames = ['images/airplane1.png',\n",
" 'images/automobile1.png',\n",
" 'images/bird1.png',\n",
" 'images/cat1.png',\n",
" 'images/deer1.png',\n",
" 'images/dog1.png',\n",
" 'images/frog1.png',\n",
" 'images/horse1.png',\n",
" 'images/ship1.png',\n",
" 'images/truck1.png']\n",
"\n",
"image_data = read_images(filenames)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The predictor runs inference on our input data and returns the predicted class label (as a float value, so we convert to int for display)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true,
"scrolled": true
},
"outputs": [],
"source": [
"for i, img in enumerate(image_data):\n",
" response = predictor.predict(img)\n",
" print('image {}: class: {}'.format(i, int(response)))\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Cleanup\n",
"\n",
"After you have finished with this example, remember to delete the prediction endpoint to release the instance(s) associated with it."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"m.delete_endpoint()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python [conda env:mxnet_p27]",
"language": "python",
"name": "conda-env-mxnet_p27-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.13"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Loading