A PyTorch implementation of Single Shot MultiBox Detector from the 2016 paper by Wei Liu, Dragomir Anguelov, Dumitru Erhan, Christian Szegedy, Scott Reed, Cheng-Yang Fu, and Alexander C. Berg. The official and original Caffe code can be found here.
- Install PyTorch by selecting your environment on the website and running the appropriate command.
- Clone this repository.
- Note: We currently only support Python 3+.
- Then download the dataset by following the instructions below.
- We now support Visdom for real-time loss visualization during training!
- To use Visdom in the browser:
# First install Python server and client
pip install visdom
# Start the server (probably in a screen or tmux)
python -m visdom.server
- Then (during training) navigate to http://localhost:8097/ (see the Train section below for training details).
- Note: For training, we currently only support VOC, but are adding COCO and hopefully ImageNet soon.
- UPDATE: We have switched from PIL Image support to cv2. The plan is to create a branch that uses PIL as well.
To make things easy, we provide a simple VOC dataset loader that inherits torch.utils.data.Dataset, making it fully compatible with the torchvision.datasets API.
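As a rough illustration of what "inherits torch.utils.data.Dataset" involves, a map-style dataset only needs `__len__` and `__getitem__`. The sketch below is not the repository's loader: the class name `ToyDetectionDataset`, the in-memory samples, and the `[xmin, ymin, xmax, ymax, label]` box layout are all assumptions made for the example. It falls back to a plain `object` base class so it still runs without PyTorch installed.

```python
try:
    # Use the real base class when PyTorch is installed.
    from torch.utils.data import Dataset
except ImportError:
    # Fallback so the sketch still runs without PyTorch.
    Dataset = object


class ToyDetectionDataset(Dataset):
    """Hypothetical map-style dataset: one (image, boxes) pair per index."""

    def __init__(self, samples):
        # samples: list of (image, boxes) pairs. A real loader would read
        # these from JPEGImages/ and Annotations/ on demand instead.
        self.samples = list(samples)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        image, boxes = self.samples[index]
        return image, boxes


# Placeholder data; the box layout here is an assumption for illustration.
dataset = ToyDetectionDataset([
    ("img0", [[48, 240, 195, 371, 11]]),
    ("img1", [[8, 12, 352, 498, 14]]),
])
print(len(dataset), dataset[0])
```

Any object with these two methods can be handed to `torch.utils.data.DataLoader`, which is what makes the loader interchangeable with the torchvision datasets.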
# specify a directory for dataset to be downloaded into, else default is ~/data/
sh data/scripts/VOC2007.sh # <directory>
# specify a directory for dataset to be downloaded into, else default is ~/data/
sh data/scripts/VOC2012.sh # <directory>
Ensure the following directory structure (as specified in VOCdevkit):
VOCdevkit/ % development kit
VOCdevkit/VOC2007/ImageSets % image sets
VOCdevkit/VOC2007/Annotations % annotation files
VOCdevkit/VOC2007/JPEGImages % images
VOCdevkit/VOC2007/SegmentationObject % segmentations by object
VOCdevkit/VOC2007/SegmentationClass % segmentations by class
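A quick way to confirm the layout above is to check for each expected subdirectory before training. This is a minimal sketch, not part of the repository; the function name `missing_voc_dirs` is hypothetical, and the example path assumes the download script placed `VOCdevkit` under the default `~/data/`.

```python
from pathlib import Path

# Subdirectories listed above, relative to VOCdevkit/VOC2007/.
EXPECTED_SUBDIRS = (
    "ImageSets",
    "Annotations",
    "JPEGImages",
    "SegmentationObject",
    "SegmentationClass",
)


def missing_voc_dirs(devkit_root, year="VOC2007"):
    """Return the expected subdirectories that are absent under devkit_root/year."""
    base = Path(devkit_root) / year
    return [name for name in EXPECTED_SUBDIRS if not (base / name).is_dir()]


# Assumed location; adjust if you passed a different directory to the script.
print(missing_voc_dirs(Path.home() / "data" / "VOCdevkit"))
```

An empty list means every directory is present; otherwise the list names what is missing.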
- First download the fc-reduced VGG-16 PyTorch base network weights at: https://s3.amazonaws.com/amdegroot-models/vgg16_reducedfc.pth
- By default, we assume you have downloaded the file into the ssd.pytorch/weights dir:
mkdir weights
cd weights
wget https://s3.amazonaws.com/amdegroot-models/vgg16_reducedfc.pth
- To train SSD using the train script, simply specify the parameters listed in train.py as flags or change them manually.
python train.py
- Training Parameter Options:
parser = argparse.ArgumentParser(description='Single Shot MultiBox Detector Training')
parser.add_argument('--version', default='v2', help='conv11_2(v2) or pool6(v1) as last layer')
parser.add_argument('--basenet', default='vgg16_reducedfc.pth', help='pretrained base model')
parser.add_argument('--jaccard_threshold', default=0.5, type=float, help='Min Jaccard index for matching')
parser.add_argument('--batch_size', default=32, type=int, help='Batch size for training')
parser.add_argument('--num_workers', default=4, type=int, help='Number of workers used in dataloading')
parser.add_argument('--iterations', default=120000, type=int, help='Number of training iterations')
parser.add_argument('--cuda', default=True, type=bool, help='Use cuda to train model')
parser.add_argument('--lr', '--learning-rate', default=1e-3, type=float, help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
parser.add_argument('--weight_decay', default=5e-4, type=float, help='Weight decay for SGD')
parser.add_argument('--gamma', default=0.1, type=float, help='Gamma update for SGD')
parser.add_argument('--log_iters', default=True, type=bool, help='Print the loss at each iteration')
parser.add_argument('--visdom', default=False, type=bool, help='Use visdom for loss visualization')
parser.add_argument('--save_folder', default='weights/', help='Location to save checkpoint models')
args = parser.parse_args()
- Note:
- For training, an NVIDIA GPU is strongly recommended for speed.
- Currently we only support training on v2 (the newest version).
- For instructions on Visdom usage/installation, see the Installation section.
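One thing to watch when passing flags: the parameter listing above declares `--cuda`, `--log_iters`, and `--visdom` with `type=bool`. With argparse, that calls `bool()` on the raw string, so `--cuda False` still evaluates to `True`. The sketch below shows the behavior and one common workaround. The `str2bool` helper is hypothetical and is not part of the listing above.

```python
import argparse


def str2bool(value):
    """Hypothetical helper: map common true/false spellings to a bool."""
    text = str(value).strip().lower()
    if text in ("yes", "true", "t", "1"):
        return True
    if text in ("no", "false", "f", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


# type=bool calls bool() on the raw string, so any non-empty text is True.
naive = argparse.ArgumentParser()
naive.add_argument("--cuda", default=True, type=bool)

# type=str2bool interprets the text instead.
parsed = argparse.ArgumentParser()
parsed.add_argument("--cuda", default=True, type=str2bool)

print(naive.parse_args(["--cuda", "False"]).cuda)   # True
print(parsed.parse_args(["--cuda", "False"]).cuda)  # False
```

Until the script uses something like this, the safest way to turn a boolean option off is to change its default in train.py.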
To evaluate a trained network:
python test.py
You can specify the parameters listed in the test.py file by flagging them or manually changing them.
Original | Test (weiliu89 weights) | Train (w/o data aug) and Test* |
---|---|---|
77.2 % | 77.26 % | 58.12%* |
* note: Obtained with a constant learning rate of 1e-4 @15k iterations. With proper adjustment, we believe this should increase substantially even w/o data aug.
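For comparison with the constant learning rate mentioned in the note, a stepped schedule can be sketched as follows. It assumes `--gamma` acts as a multiplicative decay factor applied each time training passes a milestone iteration. The milestone values and the `stepped_lr` name are hypothetical, chosen only for illustration.

```python
def stepped_lr(base_lr, gamma, iteration, milestones):
    """Learning rate after multiplying by gamma at each milestone passed."""
    steps_passed = sum(1 for m in milestones if iteration >= m)
    return base_lr * (gamma ** steps_passed)


# Hypothetical milestones; 1e-3 and 0.1 are the --lr and --gamma defaults.
MILESTONES = (80000, 100000)
for it in (0, 80000, 100000):
    print(it, stepped_lr(1e-3, 0.1, it, MILESTONES))
```

With these assumed milestones the rate drops by a factor of ten at iteration 80000 and again at 100000.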
GTX 1060: ~45.45 FPS for detection on a single image
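At ~45.45 FPS, each detection takes roughly 22 ms. We do not document here exactly how that figure was measured, so the following is only a generic sketch of timing repeated calls. The `measure_fps` name and the stand-in workload are assumptions; a real measurement would pass the network's forward pass as `detect`.

```python
import time


def measure_fps(detect, image, runs=50, warmup=5):
    """Average calls per second of detect(image) over `runs` timed calls."""
    for _ in range(warmup):
        detect(image)  # untimed warm-up calls
    start = time.perf_counter()
    for _ in range(runs):
        detect(image)
    elapsed = time.perf_counter() - start
    return runs / elapsed


# Stand-in workload; replace with a real forward pass through the network.
fps = measure_fps(lambda img: sum(img), list(range(1000)))
print(f"{fps:.1f} FPS")
```

When timing on a GPU, call `torch.cuda.synchronize()` before reading the clock, since CUDA kernels run asynchronously.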
- We are trying to provide PyTorch state_dicts (dict of weight tensors) of the latest SSD model definitions trained on different datasets.
- Currently, we provide the following PyTorch models:
- SSD300 v2 trained on VOC0712 (newest version)
- SSD300 v1 (original/old pool6 version) trained on VOC07
- Our goal is to reproduce this table from the original paper
- Make sure you have jupyter notebook installed.
- Two alternatives for installing jupyter notebook:
# make sure pip is upgraded
pip3 install --upgrade pip
# install jupyter notebook
pip install jupyter
# Run this inside ssd.pytorch
jupyter notebook
- Now navigate to demo/demo.ipynb at http://localhost:8888 (by default) and have at it!
- Works on CPU (may have to tweak cv2.waitKey for optimal fps) or on an NVIDIA GPU
- This demo currently requires opencv2+ w/ python bindings and an onboard webcam
- You can change the default webcam in demo/live.py
- Install the imutils package to leverage multi-threading on CPU:
pip install imutils
- Running python -m demo.live opens the webcam and begins detecting!
We have accumulated the following to-do list, which you can expect to be done in the very near future:
- In progress:
- Complete data augmentation (progress in augmentation branch)
- Produce a "from scratch" PyTorch mAP matching the original Caffe result
- Still to come:
- Train SSD300 with batch norm
- Add support for SSD512 training and testing
- Add support for COCO dataset
- Create a functional model definition for Sergey Zagoruyko's functional-zoo
- Wei Liu, et al. "SSD: Single Shot MultiBox Detector." ECCV 2016.
- Original Implementation (Caffe)
- A list of other great SSD ports that were sources of inspiration (especially the Chainer repo):