segmentation tutorial #5

Merged (2 commits, Apr 10, 2018)
2 changes: 2 additions & 0 deletions .gitignore
@@ -2,6 +2,8 @@
__pycache__/
*.py[cod]
*$py.class
*.swp
.DS_Store

# C extensions
*.so
3 changes: 3 additions & 0 deletions docs/api/python/datasets.rst
@@ -1,5 +1,8 @@
Vision Datasets
===============
.. automodule:: gluonvision.data
.. currentmodule:: gluonvision.data

Popular datasets for vision tasks are provided in gluonvision.
By default, we require all datasets to reside in ``~/.mxnet/datasets/`` in order to provide
a frustration-free user experience with less path bookkeeping.
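
For example, loading a dataset then takes one line (a minimal sketch using the dataset
class from the tutorial in this PR; the exact ``root`` subfolder is an assumption)::

    import os
    import gluonvision

    # datasets are expected under ~/.mxnet/datasets/ by default;
    # the 'voc' subfolder name below is an assumption
    root = os.path.expanduser('~/.mxnet/datasets/voc')
    train_set = gluonvision.data.VOCSegmentationDataset(root)
    print(len(train_set))  # number of (image, label-mask) pairs
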
36 changes: 29 additions & 7 deletions docs/api/python/models.rst
@@ -27,37 +27,65 @@ Dilated Network

We apply a dilation strategy to pre-trained ResNet models (producing feature maps with a stride of 8). Please see :class:`gluonvision.model_zoo.SegBaseModel` for how to use it.

:hidden:`Dilated_ResNetV2`
~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: Dilated_ResNetV2
:members:

:hidden:`DilatedBasicBlockV2`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: DilatedBasicBlockV2
:members:

:hidden:`DilatedBottleneckV2`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: DilatedBottleneckV2
:members:

:hidden:`get_dilated_resnet`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: get_dilated_resnet


:hidden:`dilated_resnet18`
~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: dilated_resnet18


:hidden:`dilated_resnet34`
~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: dilated_resnet34


:hidden:`dilated_resnet50`
~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: dilated_resnet50


:hidden:`dilated_resnet101`
~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: dilated_resnet101


:hidden:`dilated_resnet152`
~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: dilated_resnet152


Object Detection
----------------

:hidden:`SSD`
~~~~~~~~~~~~~

.. autoclass:: SSD
:members:
@@ -77,12 +105,6 @@ Semantic Segmentation
.. autoclass:: FCN
:members:

:hidden:`PSPNet`
~~~~~~~~~~~~~~~~

.. autoclass:: PSPNet
:members:


Common Components
-----------------
226 changes: 165 additions & 61 deletions docs/experiments/semanticseg.rst
@@ -7,16 +7,173 @@ _____________________
This is a step-by-step semantic segmentation tutorial using Gluon Vision.
Readers should have basic knowledge of deep learning and be familiar with the Gluon API.
New users may first go through the Gluon tutorials
`Deep Learning - The Straight Dope <http://gluon.mxnet.io/>`_.

Fully Convolutional Network
---------------------------

.. image:: https://cdn-images-1.medium.com/max/800/1*wRkj6lsQ5ckExB5BoYkrZg.png
:width: 70%
:align: center

(figure credit to `Long et al. <https://arxiv.org/pdf/1411.4038.pdf>`_ )

State-of-the-art approaches to semantic segmentation are typically based on the
Fully Convolutional Network (FCN) [Long15]_ .
The key idea of a fully convolutional network is that it is "fully convolutional":
it does not have any fully connected layers. Therefore, the network can
accept input of arbitrary size and make dense per-pixel predictions.
The base (encoder) network is typically pre-trained on ImageNet, because the features
learned from a diverse set of images contain rich contextual information, which
is beneficial for semantic segmentation.
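
To make this concrete, here is a minimal standalone sketch (not part of the model zoo)
of a network built only from convolutional layers; it accepts inputs of arbitrary size
and emits per-pixel scores::

    from mxnet import nd
    from mxnet.gluon import nn

    # a conv-only "network": no Dense layers, so any input size works
    net = nn.HybridSequential()
    net.add(nn.Conv2D(channels=16, kernel_size=3, padding=1, activation='relu'))
    net.add(nn.Conv2D(channels=4, kernel_size=1))  # 1x1 conv -> per-pixel scores
    net.initialize()

    for h, w in [(224, 224), (480, 640)]:
        out = net(nd.random.uniform(shape=(1, 3, h, w)))
        print(out.shape)  # (1, 4, h, w): dense predictions at the input resolution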


Model Dilation
~~~~~~~~~~~~~~

Adapting a base network pre-trained on ImageNet leads to a loss of spatial resolution,
because these networks are originally designed for the classification task.
Following recent work in semantic segmentation, we apply the dilation strategy to
stages 3 and 4 of the pre-trained networks, which produces feature maps with a stride of 8
(models are provided in :class:`gluonvision.model_zoo.Dilated_ResNetV2`).
Visualization of dilated/atrous convolution:

.. image:: https://raw.githubusercontent.com/vdumoulin/conv_arithmetic/master/gif/dilation.gif
:width: 40%
:align: center

(figure credit to `conv_arithmetic <https://github.com/vdumoulin/conv_arithmetic>`_ )
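
As a quick illustration (a standalone sketch, separate from the model zoo code), a 3x3
convolution with ``dilation=2`` keeps the feature-map size while enlarging the
receptive field::

    from mxnet import nd
    from mxnet.gluon import nn

    x = nd.random.uniform(shape=(1, 64, 32, 32))

    conv = nn.Conv2D(channels=64, kernel_size=3, padding=1)                  # 3x3 field
    dilated = nn.Conv2D(channels=64, kernel_size=3, padding=2, dilation=2)   # 5x5 field
    conv.initialize()
    dilated.initialize()
    # with matching padding, both preserve the 32x32 spatial size
    print(conv(x).shape, dilated(x).shape)  # (1, 64, 32, 32) twice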

For example, loading a dilated ResNet50 is simply::

    pretrained_net = gluonvision.model_zoo.dilated_resnet50(pretrained=True)

For convenience, we provide :class:`gluonvision.model_zoo.SegBaseModel`, a base model
for semantic segmentation that automatically loads the pre-trained dilated ResNet and
can be easily inherited and used.

FCN Block
~~~~~~~~~

We build a fully convolutional "head" on top of the base network (the FCN model is
provided in :class:`gluonvision.model_zoo.FCN`)::

    from mxnet import init
    from mxnet import ndarray as F          # used by FCN.forward below
    from mxnet.gluon import nn, HybridBlock
    from gluonvision.model_zoo import SegBaseModel

    class _FCNHead(HybridBlock):
        def __init__(self, nclass, norm_layer):
            super(_FCNHead, self).__init__()
            with self.name_scope():
                self.block = nn.HybridSequential(prefix='')
                self.block.add(norm_layer(in_channels=2048))
                self.block.add(nn.Activation('relu'))
                self.block.add(nn.Conv2D(in_channels=2048, channels=512,
                                         kernel_size=3, padding=1))
                self.block.add(norm_layer(in_channels=512))
                self.block.add(nn.Activation('relu'))
                self.block.add(nn.Dropout(0.1))
                self.block.add(nn.Conv2D(in_channels=512, channels=nclass,
                                         kernel_size=1))

        def hybrid_forward(self, F, x):
            return self.block(x)

    class FCN(SegBaseModel):
        def __init__(self, nclass, backbone='resnet50', norm_layer=nn.BatchNorm):
            super(FCN, self).__init__(backbone, norm_layer)
            self._prefix = ''
            with self.name_scope():
                self.head = _FCNHead(nclass, norm_layer=norm_layer)
            self.head.initialize(init=init.Xavier())

        def forward(self, x):
            _, _, H, W = x.shape
            x = self.pretrained(x)           # dilated backbone, stride-8 features
            x = self.head(x)                 # per-pixel class scores
            # upsample the scores back to the input resolution
            x = F.contrib.BilinearResize2D(x, height=H, width=W)
            return x
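
A quick shape sanity check of the model above (a sketch; it assumes the classes are
defined as shown and the pre-trained dilated backbone can be downloaded)::

    from mxnet import nd

    net = FCN(nclass=22)                 # 22 matches the demo later in this tutorial
    x = nd.random.uniform(shape=(1, 3, 480, 480))
    print(net(x).shape)                  # (1, 22, 480, 480): per-pixel class scores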

Dataset and Data Augmentation
-----------------------------

We provide semantic segmentation datasets in :class:`gluonvision.data`.
For example, we can easily get the Pascal VOC 2012 dataset::

    train_set = gluonvision.data.VOCSegmentationDataset(root)

We follow the standard data augmentation routine to transform the input image
and the ground-truth label map synchronously. (Note that "nearest"
mode upsampling is applied to the label maps to avoid messing up the boundaries.)
We first randomly scale the input image by a factor of 0.5 to 2.0, then rotate
it by -10 to 10 degrees, and crop it with padding if needed.
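
For concreteness, here is a sketch of such a synchronized transform using PIL (the
helper below is an assumption for illustration; the training script implements its own
version, and the ``crop_size`` and ``fill=255`` ignore label are assumed values)::

    import random
    from PIL import Image, ImageOps

    def joint_transform(img, mask, crop_size=480):
        # randomly scale the pair by 0.5x-2.0x; "nearest" for the mask so
        # label ids are not blended across boundaries
        scale = random.uniform(0.5, 2.0)
        w, h = img.size
        size = (int(w * scale), int(h * scale))
        img = img.resize(size, Image.BILINEAR)
        mask = mask.resize(size, Image.NEAREST)
        # rotate both by the same random angle in [-10, 10] degrees
        deg = random.uniform(-10, 10)
        img = img.rotate(deg, resample=Image.BILINEAR)
        mask = mask.rotate(deg, resample=Image.NEAREST)
        # pad (if needed) so a crop_size x crop_size crop fits, then random-crop
        pad_w = max(crop_size - img.size[0], 0)
        pad_h = max(crop_size - img.size[1], 0)
        if pad_w or pad_h:
            img = ImageOps.expand(img, border=(0, 0, pad_w, pad_h), fill=0)
            mask = ImageOps.expand(mask, border=(0, 0, pad_w, pad_h), fill=255)
        x = random.randint(0, img.size[0] - crop_size)
        y = random.randint(0, img.size[1] - crop_size)
        box = (x, y, x + crop_size, y + crop_size)
        return img.crop(box), mask.crop(box)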

.. todo::

    add a gif showing the augmentation

Benchmarks and Training
_______________________

Test Pre-trained Model
~~~~~~~~~~~~~~~~~~~~~~

.. code-block:: python

    # load the pre-trained model
    model = FCN(nclass=22, backbone='resnet101')
    model.load_params('fcn101.params')
    model.collect_params().reset_ctx(ctx)

    # read the image and normalize the data
    # (Compose/ToTensor/Normalize come from the tutorial's transform utilities,
    # Image is PIL.Image, and F is mxnet.ndarray)
    transform = Compose([
        ToTensor(ctx=ctx),
        Normalize(mean=[.485, .456, .406], std=[.229, .224, .225], ctx=ctx)])

    def load_image(path, transform, ctx):
        image = Image.open(path).convert('RGB')
        image = transform(image)
        image = image.expand_dims(0).as_in_context(ctx)
        return image

    image = load_image('example.jpg', transform, ctx)

    # make a prediction using a single scale
    output = model(image)
    predict = F.squeeze(F.argmax(output, 1)).asnumpy()

    # add the color palette for visualization
    mask = get_mask(predict, 'pascal_voc')
    mask.save('output.png')

Please see ``demo.py`` for more evaluation options.

.. image:: ../../scripts/segmentation/examples/1.jpg
:width: 45%

.. image:: ../../scripts/segmentation/examples/1.png
:width: 45%

.. image:: ../../scripts/segmentation/examples/4.jpg
:width: 45%

.. image:: ../../scripts/segmentation/examples/4.png
:width: 45%

.. image:: ../../scripts/segmentation/examples/5.jpg
:width: 45%

.. image:: ../../scripts/segmentation/examples/5.png
:width: 45%

.. image:: ../../scripts/segmentation/examples/6.jpg
:width: 45%

.. image:: ../../scripts/segmentation/examples/6.png
:width: 45%



- Check out the training scripts to reproduce the experiments, and see the detailed
  running instructions in the `README <https://github.com/dmlc/gluon-vision/tree/master/scripts/segmentation>`_ .

- Table of pre-trained models and their performance (:math:`^\ast` denotes models pre-trained on COCO):

.. role:: raw-html(raw)
@@ -31,8 +188,6 @@ Test Pre-trained Model
+------------------------+------------+-----------+-----------+-----------+-----------+----------------------------------------------------------------------------------------------+
| FCN | ResNet101 | PASCAL12 | stride 8 | N/A | | :raw-html:`<a href="javascript:toggleblock('cmd_fcn_101')" class="toggleblock">cmd</a>` |
+------------------------+------------+-----------+-----------+-----------+-----------+----------------------------------------------------------------------------------------------+
| PSPNet | ResNet50 | PASCAL12 | w/o aux | N/A | | :raw-html:`<a href="javascript:toggleblock('cmd_psp_50')" class="toggleblock">cmd</a>` |
+------------------------+------------+-----------+-----------+-----------+-----------+----------------------------------------------------------------------------------------------+

.. _70.9: http://host.robots.ox.ac.uk:8080/anonymous/FR9APO.html

@@ -75,60 +230,9 @@ Test Pre-trained Model
CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py --dataset pascal_voc --model pspnet --backbone resnet101 --lr 0.0001 --syncbn --checkname mycheckpoint --resume runs/pascal_aug/fcn/mycheckpoint/checkpoint.params
</code>

Train Your Own Model
~~~~~~~~~~~~~~~~~~~~

- Prepare PASCAL VOC Dataset and Augmented Dataset::

    cd examples/datasets/
    python setup_pascal_voc.py
    python setup_pascal_aug.py

- Training command example::

    # first train on the augmented set
    CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py --dataset pascal_aug --model fcn --backbone resnet50 --lr 0.001 --checkname mycheckpoint
    # then fine-tune on the original set
    CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py --dataset pascal_voc --model fcn --backbone resnet50 --lr 0.0001 --checkname mycheckpoint --resume runs/pascal_aug/fcn/mycheckpoint/checkpoint.params

For more training commands, please see the ``Commands`` entries in the pre-trained model Table_.

- Detailed training options (a sketch of the mIoU metric reported by ``--eval``
  follows this list)::

    -h, --help            show this help message and exit
    --model MODEL         model name (default: fcn)
    --backbone BACKBONE   backbone name (default: resnet50)
    --dataset DATASET     dataset name (default: pascal)
    --nclass NCLASS       nclass for pre-trained model (default: None)
    --workers N           dataloader threads
    --data-folder         training dataset folder (default: $(HOME)/data/)
    --epochs N            number of epochs to train (default: 50)
    --start_epoch N       starting epoch (default: 0)
    --batch-size N        input batch size for training (default: 16)
    --test-batch-size N   input batch size for testing (default: 32)
    --lr LR               learning rate (default: 1e-3)
    --momentum M          momentum (default: 0.9)
    --weight-decay M      w-decay (default: 1e-4)
    --kvstore KVSTORE     kvstore to use for trainer/module
    --no-cuda             disables CUDA training
    --ngpus NGPUS         number of GPUs (default: 4)
    --seed S              random seed (default: 1)
    --resume RESUME       path to a checkpoint to resume from, if needed
    --checkname           set the checkpoint name
    --eval                evaluate mIoU
    --test                test a set of images and save the predictions
    --syncbn              use Synchronized Cross-GPU BatchNorm
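
The ``--eval`` option reports mIoU. For reference, here is a minimal numpy sketch of
that metric (assumed semantics; the script's own implementation may differ in details
such as the ignore label)::

    import numpy as np

    def mean_iou(pred, label, nclass, ignore=255):
        # per-class IoU = TP / (TP + FP + FN), averaged over observed classes
        valid = label != ignore
        pred, label = pred[valid], label[valid]
        ious = []
        for c in range(nclass):
            inter = np.logical_and(pred == c, label == c).sum()
            union = np.logical_or(pred == c, label == c).sum()
            if union > 0:
                ious.append(float(inter) / union)
        return float(np.mean(ious)) if ious else float('nan')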

Extending the Software
~~~~~~~~~~~~~~~~~~~~~~

- Add your own dataloader ``mydataset.py`` to the ``gluonvision/datasets/`` folder
  (a hypothetical skeleton follows this list)

- Add your own model ``mymodel.py`` to the ``gluonvision/models/`` folder

- Run the program:

  .. code:: bash

      python main.py --dataset mydataset --model mymodel --nclass 10 ...
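
For illustration, a hypothetical ``mydataset.py`` skeleton (the directory layout and
interface here are assumptions; match the existing datasets in ``gluonvision/datasets/``
for the real contract)::

    import os
    from PIL import Image
    from mxnet.gluon.data import Dataset

    class MyDataset(Dataset):
        # assumed layout: root/images/*.jpg paired with root/masks/*.png
        def __init__(self, root, transform=None):
            names = sorted(os.listdir(os.path.join(root, 'images')))
            self.images = [os.path.join(root, 'images', n) for n in names]
            self.masks = [os.path.join(root, 'masks',
                                       os.path.splitext(n)[0] + '.png')
                          for n in names]
            self.transform = transform

        def __getitem__(self, idx):
            img = Image.open(self.images[idx]).convert('RGB')
            mask = Image.open(self.masks[idx])
            if self.transform is not None:
                img = self.transform(img)
            return img, mask

        def __len__(self):
            return len(self.images)
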
References
----------

.. [Long15] Long, Jonathan, Evan Shelhamer, and Trevor Darrell.
   "Fully convolutional networks for semantic segmentation."
   Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2015.
12 changes: 10 additions & 2 deletions docs/get_started/install.md
@@ -1,6 +1,14 @@
# Installation Guide

### Install MXNet

To run the tutorials, a recent version of MXNet is required. The easiest way is to install the nightly MXNet build through `pip`, e.g.:

    pip install mxnet --pre --user

### Install GluonVision

    # clone the repo and install from source
    git clone https://github.com/dmlc/gluon-vision
    cd gluon-vision
    python setup.py install
2 changes: 0 additions & 2 deletions gluonvision/utils/parallel.py
@@ -39,7 +39,6 @@ class ModelDataParallel(object):
    >>> y = net(x)
    """
    def __init__(self, module, ctx, sync=False):
        #super(ModelDataParallel, self).__init__()
        self.ctx = ctx
        module.collect_params().reset_ctx(ctx=ctx)
        self.module = module
@@ -89,7 +88,6 @@ class CriterionDataParallel(object):
    >>> losses = criterion(y, t)
    """
    def __init__(self, module, sync=False):
        #super(CriterionDataParallel, self).__init__()
        self.module = module
        self.sync = sync
