[WIP] script for distributed ssd training #541

Closed (wants to merge 1 commit)
1 change: 1 addition & 0 deletions gluoncv/data/__init__.py
@@ -17,6 +17,7 @@
from .recordio.detection import RecordFileDetection
from .lst.detection import LstDetection
from .mixup.detection import MixupDetection
from .sampler import SplitSampler

datasets = {
'ade20k': ADE20KSegmentation,
33 changes: 33 additions & 0 deletions gluoncv/data/sampler.py
@@ -0,0 +1,33 @@
import random

from mxnet import gluon

__all__ = ['SplitSampler']

class SplitSampler(gluon.data.sampler.Sampler):
""" Split the dataset into `num_parts` parts and sample from the part with index `part_index`

Parameters
----------
length: int
Number of examples in the dataset
num_parts: int
        Number of partitions to split the dataset into
part_index: int
The index of the part to read from
"""
def __init__(self, length, num_parts=1, part_index=0):
        # Compute the length of each partition; remainder samples are dropped
        self.part_len = length // num_parts
# Compute the start index for this partition
self.start = self.part_len * part_index
# Compute the end index for this partition
self.end = self.start + self.part_len

def __iter__(self):
# Extract examples between `start` and `end`, shuffle and return them.
indices = list(range(self.start, self.end))
random.shuffle(indices)
return iter(indices)

def __len__(self):
return self.part_len
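For context, a minimal sketch of how such a sampler is typically plugged into a `DataLoader` so that each worker reads only its own shard. The stand-in dataset and the `dist_sync` kvstore setup below are illustrative assumptions, not part of this PR:

```python
import mxnet as mx
from mxnet import gluon
from gluoncv.data import SplitSampler

# hypothetical layout: one data partition per worker in a 'dist_sync' job;
# kv.num_workers and kv.rank come from the distributed kvstore
kv = mx.kv.create('dist_sync')
dataset = gluon.data.ArrayDataset(mx.nd.arange(1000))   # stand-in dataset
sampler = SplitSampler(len(dataset), num_parts=kv.num_workers, part_index=kv.rank)
loader = gluon.data.DataLoader(dataset, batch_size=32, sampler=sampler,
                               last_batch='rollover', num_workers=4)
```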
65 changes: 65 additions & 0 deletions gluoncv/loss.py
@@ -95,6 +95,71 @@ def _as_list(arr):
return arr


class HybridSSDMultiBoxLoss(gluon.HybridBlock):
r"""Single-Shot Multibox Object Detection Loss.

.. note::

`HybridSSDMultiBoxLoss` is a `HybridBlock` version of `SSDMultiBoxLoss`. However,
there are two differences:

- It avoids cross device synchronization in `hybrid_forward()`, which may result in
better throughput.
- It additionally returns the number of positive targets, which should be used to
rescale gradients manually before `trainer.step()` is performed.

Parameters
----------
negative_mining_ratio : float, default is 3
Ratio of negative vs. positive samples.
rho : float, default is 1.0
Threshold for trimmed mean estimator. This is the smooth parameter for the
L1-L2 transition.
lambd : float, default is 1.0
Relative weight between classification and box regression loss.
The overall loss is computed as :math:`L = loss_{class} + \lambda \times loss_{loc}`.

Inputs:
- **cls_pred**: the prediction tensor.
- **box_pred**: the box prediction tensor.
- **cls_target**: the class target tensor.
- **box_target**: the box target tensor.

Outputs:
- **sum_loss**: overall class and box prediction loss.
- **cls_loss**: class prediction loss.
- **box_loss**: box prediction loss.
- **num_pos**: number of positive targets in the batch (scalar).
"""
def __init__(self, negative_mining_ratio=3, rho=1.0, lambd=1.0, **kwargs):
super(HybridSSDMultiBoxLoss, self).__init__(**kwargs)
self._negative_mining_ratio = max(0, negative_mining_ratio)
self._rho = rho
self._lambd = lambd

def hybrid_forward(self, F, cls_pred, box_pred, cls_target, box_target):
"""Compute loss in entire batch across devices."""
pos = cls_target > 0
num_pos = pos.sum()
pred = F.log_softmax(cls_pred, axis=-1)
cls_loss = -F.pick(pred, cls_target, axis=-1, keepdims=False)
        # hard negative mining: rank negatives by classification loss and keep
        # at most `negative_mining_ratio` negatives per positive sample
        rank = F.broadcast_mul(cls_loss, (pos - 1)).argsort(axis=1).argsort(axis=1)
        hard_negative = F.broadcast_lesser(rank, (pos.sum(axis=1) * self._negative_mining_ratio).expand_dims(-1))
# mask out if not positive or negative
cls_loss = F.where((pos + hard_negative) > 0, cls_loss, F.zeros_like(cls_loss))
        # intentionally not normalized by num_pos, to avoid a cross-device sync;
        # the caller rescales gradients with `num_pos` before trainer.step()
        cls_loss = F.sum(cls_loss, axis=0, exclude=True)

box_pred = _reshape_like(F, box_pred, box_target)
box_loss = F.abs(box_pred - box_target)
box_loss = F.where(box_loss > self._rho, box_loss - 0.5 * self._rho,
(0.5 / self._rho) * box_loss.square())
        # box loss only applies to positive samples
        box_loss = F.broadcast_mul(box_loss, pos.expand_dims(axis=-1))
        # again left un-normalized; gradients are rescaled by the caller
        box_loss = F.sum(box_loss, axis=0, exclude=True)
sum_loss = cls_loss + self._lambd * box_loss
return sum_loss, cls_loss, box_loss, num_pos


class SSDMultiBoxLoss(gluon.Block):
r"""Single-Shot Multibox Object Detection Loss.

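To make the note about manual gradient rescaling concrete, here is a rough single-device training-step sketch. `net`, `trainer`, `x`, `cls_targets`, and `box_targets` are assumed to be defined elsewhere, and passing `num_pos` to `trainer.step()` is just one way to apply the normalization the loss intentionally skips:

```python
from mxnet import autograd
from gluoncv.loss import HybridSSDMultiBoxLoss

mbox_loss = HybridSSDMultiBoxLoss()

# `net`, `trainer`, `x`, `cls_targets`, `box_targets` are assumed to be
# set up elsewhere (e.g. by an SSD training script)
with autograd.record():
    cls_preds, box_preds, _ = net(x)
    sum_loss, cls_loss, box_loss, num_pos = mbox_loss(
        cls_preds, box_preds, cls_targets, box_targets)
    autograd.backward(sum_loss)

# normalize by the number of positive targets instead of inside the loss;
# trainer.step(n) divides the accumulated gradients by n before the update
trainer.step(max(float(num_pos.asscalar()), 1.0))
```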
2 changes: 1 addition & 1 deletion gluoncv/utils/__init__.py
@@ -11,6 +11,6 @@
from .filesystem import makedirs
from .bbox import bbox_iou
from .block import recursive_visit, set_lr_mult, freeze_bn
from .lr_scheduler import LRScheduler
from .lr_scheduler import LRScheduler, DistLRScheduler
from .plot_history import TrainingHistory
from .export_helper import export_block
111 changes: 111 additions & 0 deletions gluoncv/utils/lr_scheduler.py
@@ -1,6 +1,7 @@
"""Popular Learning Rate Schedulers"""
# pylint: disable=missing-docstring
from __future__ import division
import warnings

from math import pi, cos
from mxnet import lr_scheduler
@@ -28,6 +29,10 @@ class LRScheduler(lr_scheduler.LRScheduler):

lr = warmup_lr

.. note::

Please consider `DistLRScheduler` for training with dist kvstore.

Parameters
----------
mode : str
@@ -106,3 +111,109 @@ def update(self, i, epoch):
(1 + cos(pi * (T - self.warmup_N) / (self.N - self.warmup_N))) / 2
else:
raise NotImplementedError

class DistLRScheduler(lr_scheduler.LRScheduler):
Review comment (Member): Is there a way to make this PR and #353 compatible?

r"""Learning rate scheduler for distributed training with KVStore.

    For mode='step', we multiply lr by `step_factor` at each epoch listed in `step`.

For mode='poly'::

lr = targetlr + (baselr - targetlr) * (1 - iter / maxiter) ^ power

For mode='cosine'::

lr = targetlr + (baselr - targetlr) * (1 + cos(pi * iter / maxiter)) / 2

If warmup_epochs > 0, a warmup stage will be inserted before the main lr scheduler.

For warmup_mode='linear'::

lr = warmup_lr + (baselr - warmup_lr) * iter / max_warmup_iter

For warmup_mode='constant'::

lr = warmup_lr

Parameters
----------
mode : str
Modes for learning rate scheduler.
Currently it supports 'step', 'poly' and 'cosine'.
base_lr : float
Base learning rate, i.e. the starting learning rate.
niters : int
Number of iterations in each epoch.
nepochs : int
Number of training epochs.
step : list
A list of epochs to decay the learning rate.
step_factor : float
Learning rate decay factor.
target_lr : float
Target learning rate for poly and cosine, as the ending learning rate.
power : float
Power of poly function.
warmup_epochs : int
Number of epochs for the warmup stage.
warmup_lr : float
The base learning rate for the warmup stage.
warmup_mode : str
Modes for the warmup stage.
Currently it supports 'linear' and 'constant'.
"""
def __init__(self, mode, base_lr, niters, nepochs,
step=(30, 60, 90), step_factor=0.1, target_lr=0, power=0.9,
warmup_epochs=0, warmup_lr=0, warmup_mode='linear'):
super(DistLRScheduler, self).__init__()
assert(mode in ['step', 'poly', 'cosine'])
assert(warmup_mode in ['linear', 'constant'])

self.mode = mode
self.base_lr = base_lr
self.learning_rate = self.base_lr
self.niters = niters

self.step = step
self.step_factor = step_factor
self.target_lr = target_lr
self.power = power
self.warmup_epochs = warmup_epochs
self.warmup_lr = warmup_lr
self.warmup_mode = warmup_mode

self.N = nepochs * niters
self.warmup_N = warmup_epochs * niters

def __call__(self, num_update):
self._update(num_update)
return self.learning_rate

def _update(self, T):
epoch = T // self.niters
if T > self.N:
            warnings.warn("DistLRScheduler expects <= %d updates, but got num_update=%d. "
                          "This might be caused by extra data samples rolling over."
                          % (self.N, T))
return

if self.warmup_epochs > epoch:
# Warm-up Stage
if self.warmup_mode == 'linear':
self.learning_rate = self.warmup_lr + (self.base_lr - self.warmup_lr) * \
T / self.warmup_N
elif self.warmup_mode == 'constant':
self.learning_rate = self.warmup_lr
else:
raise NotImplementedError
else:
if self.mode == 'step':
count = sum([1 for s in self.step if s <= epoch])
self.learning_rate = self.base_lr * pow(self.step_factor, count)
elif self.mode == 'poly':
self.learning_rate = self.target_lr + (self.base_lr - self.target_lr) * \
pow(1 - (T - self.warmup_N) / (self.N - self.warmup_N), self.power)
elif self.mode == 'cosine':
self.learning_rate = self.target_lr + (self.base_lr - self.target_lr) * \
(1 + cos(pi * (T - self.warmup_N) / (self.N - self.warmup_N))) / 2
else:
raise NotImplementedError
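Finally, a hedged sketch of wiring `DistLRScheduler` into a `Trainer` backed by a dist kvstore. The learning rate, epoch count, and dataset size below are placeholder numbers, and `net` is assumed to be an already-initialized model:

```python
import mxnet as mx
from gluoncv.utils import DistLRScheduler

kv = mx.kv.create('dist_sync')
batch_size, num_epochs, num_train = 32, 240, 16551       # placeholder values
niters = num_train // (batch_size * kv.num_workers)      # iterations per epoch per worker

lr_sched = DistLRScheduler(mode='cosine', base_lr=0.004, niters=niters,
                           nepochs=num_epochs, warmup_epochs=2, warmup_lr=0.0)
trainer = mx.gluon.Trainer(net.collect_params(), 'sgd',
                           {'learning_rate': 0.004, 'wd': 0.0005, 'momentum': 0.9,
                            'lr_scheduler': lr_sched},
                           kvstore=kv)
```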