From 5cd5ac7356283bd908c725ad83e50084ebf270d9 Mon Sep 17 00:00:00 2001
From: Ubuntu
Date: Fri, 28 Dec 2018 21:46:04 +0000
Subject: [PATCH] Add dist training script for ssd

fix bugs

update kvstore

fix step size
---
 gluoncv/data/__init__.py                |   1 +
 gluoncv/data/sampler.py                 |  33 +++
 gluoncv/loss.py                         |  65 ++++++
 gluoncv/utils/__init__.py               |   2 +-
 gluoncv/utils/lr_scheduler.py           | 111 +++++++++
 scripts/detection/ssd/dist_train_ssd.py | 286 ++++++++++++++++++++++++
 6 files changed, 497 insertions(+), 1 deletion(-)
 create mode 100644 gluoncv/data/sampler.py
 create mode 100644 scripts/detection/ssd/dist_train_ssd.py

diff --git a/gluoncv/data/__init__.py b/gluoncv/data/__init__.py
index 163c664132..7df15c589e 100644
--- a/gluoncv/data/__init__.py
+++ b/gluoncv/data/__init__.py
@@ -17,6 +17,7 @@
 from .recordio.detection import RecordFileDetection
 from .lst.detection import LstDetection
 from .mixup.detection import MixupDetection
+from .sampler import SplitSampler
 
 datasets = {
     'ade20k': ADE20KSegmentation,
diff --git a/gluoncv/data/sampler.py b/gluoncv/data/sampler.py
new file mode 100644
index 0000000000..8b5eb512b0
--- /dev/null
+++ b/gluoncv/data/sampler.py
@@ -0,0 +1,33 @@
+from mxnet import gluon
+import random
+
+__all__ = ['SplitSampler']
+
+class SplitSampler(gluon.data.sampler.Sampler):
+    """ Split the dataset into `num_parts` parts and sample from the part with index `part_index`
+
+    Parameters
+    ----------
+    length: int
+      Number of examples in the dataset
+    num_parts: int
+      Partition the data into multiple parts
+    part_index: int
+      The index of the part to read from
+    """
+    def __init__(self, length, num_parts=1, part_index=0):
+        # Compute the length of each partition
+        self.part_len = length // num_parts
+        # Compute the start index for this partition
+        self.start = self.part_len * part_index
+        # Compute the end index for this partition
+        self.end = self.start + self.part_len
+
+    def __iter__(self):
+        # Extract examples between `start` and `end`, shuffle and return them.
+        indices = list(range(self.start, self.end))
+        random.shuffle(indices)
+        return iter(indices)
+
+    def __len__(self):
+        return self.part_len
diff --git a/gluoncv/loss.py b/gluoncv/loss.py
index 900ca5425b..ffd4bd466d 100644
--- a/gluoncv/loss.py
+++ b/gluoncv/loss.py
@@ -95,6 +95,71 @@ def _as_list(arr):
     return arr
 
 
+class HybridSSDMultiBoxLoss(gluon.HybridBlock):
+    r"""Single-Shot Multibox Object Detection Loss.
+
+    .. note::
+
+        `HybridSSDMultiBoxLoss` is a `HybridBlock` version of `SSDMultiBoxLoss`. However,
+        there are two differences:
+
+        - It avoids cross device synchronization in `hybrid_forward()`, which may result in
+          better throughput.
+        - It additionally returns the number of positive targets, which should be used to
+          rescale gradients manually before `trainer.step()` is performed.
+
+    Parameters
+    ----------
+    negative_mining_ratio : float, default is 3
+        Ratio of negative vs. positive samples.
+    rho : float, default is 1.0
+        Threshold for trimmed mean estimator. This is the smooth parameter for the
+        L1-L2 transition.
+    lambd : float, default is 1.0
+        Relative weight between classification and box regression loss.
+        The overall loss is computed as :math:`L = loss_{class} + \lambda \times loss_{loc}`.
+
+    Inputs:
+        - **cls_pred**: the prediction tensor.
+        - **box_pred**: the box prediction tensor.
+        - **cls_target**: the class target tensor.
+        - **box_target**: the box target tensor.
+
+    Outputs:
+        - **sum_loss**: overall class and box prediction loss.
+        - **cls_loss**: class prediction loss.
+        - **box_loss**: box prediction loss.
+        - **num_pos**: number of positive targets in the batch (scalar).
+    """
+    def __init__(self, negative_mining_ratio=3, rho=1.0, lambd=1.0, **kwargs):
+        super(HybridSSDMultiBoxLoss, self).__init__(**kwargs)
+        self._negative_mining_ratio = max(0, negative_mining_ratio)
+        self._rho = rho
+        self._lambd = lambd
+
+    def hybrid_forward(self, F, cls_pred, box_pred, cls_target, box_target):
+        """Compute loss in entire batch across devices."""
+        pos = cls_target > 0
+        num_pos = pos.sum()
+        pred = F.log_softmax(cls_pred, axis=-1)
+        cls_loss = -F.pick(pred, cls_target, axis=-1, keepdims=False)
+        rank = F.broadcast_mul(cls_loss, (pos - 1)).argsort(axis=1).argsort(axis=1)
+        hard_negative = F.broadcast_lesser(rank, (pos.sum(axis=1) * self._negative_mining_ratio).expand_dims(-1))
+        # mask out if not positive or negative
+        cls_loss = F.where((pos + hard_negative) > 0, cls_loss, F.zeros_like(cls_loss))
+        cls_loss = F.sum(cls_loss, axis=0, exclude=True) / 1
+
+        box_pred = _reshape_like(F, box_pred, box_target)
+        box_loss = F.abs(box_pred - box_target)
+        box_loss = F.where(box_loss > self._rho, box_loss - 0.5 * self._rho,
+                           (0.5 / self._rho) * box_loss.square())
+        # box loss only applies to positive samples
+        box_loss = F.broadcast_mul(box_loss, pos.expand_dims(axis=-1))
+        box_loss = F.sum(box_loss, axis=0, exclude=True) / 1
+        sum_loss = cls_loss + self._lambd * box_loss
+        return sum_loss, cls_loss, box_loss, num_pos
+
+
 class SSDMultiBoxLoss(gluon.Block):
     r"""Single-Shot Multibox Object Detection Loss.
 
diff --git a/gluoncv/utils/__init__.py b/gluoncv/utils/__init__.py
index 8893aeb069..9025e5f095 100644
--- a/gluoncv/utils/__init__.py
+++ b/gluoncv/utils/__init__.py
@@ -11,6 +11,6 @@
 from .filesystem import makedirs
 from .bbox import bbox_iou
 from .block import recursive_visit, set_lr_mult, freeze_bn
-from .lr_scheduler import LRScheduler
+from .lr_scheduler import LRScheduler, DistLRScheduler
 from .plot_history import TrainingHistory
 from .export_helper import export_block
diff --git a/gluoncv/utils/lr_scheduler.py b/gluoncv/utils/lr_scheduler.py
index eb52bee629..3075ed5cb7 100644
--- a/gluoncv/utils/lr_scheduler.py
+++ b/gluoncv/utils/lr_scheduler.py
@@ -1,6 +1,7 @@
 """Popular Learning Rate Schedulers"""
 # pylint: disable=missing-docstring
 from __future__ import division
+import warnings
 
 from math import pi, cos
 from mxnet import lr_scheduler
@@ -28,6 +29,10 @@ class LRScheduler(lr_scheduler.LRScheduler):
 
         lr = warmup_lr
 
+    .. note::
+
+        Please consider `DistLRScheduler` for training with dist kvstore.
+
     Parameters
     ----------
     mode : str
@@ -106,3 +111,109 @@ def update(self, i, epoch):
                 (1 + cos(pi * (T - self.warmup_N) / (self.N - self.warmup_N))) / 2
         else:
             raise NotImplementedError
+
+class DistLRScheduler(lr_scheduler.LRScheduler):
+    r"""Learning rate scheduler for distributed training with KVStore.
+
+    For mode='step', we multiply lr with `step_factor` at each epoch in `step`.
+
+    For mode='poly'::
+
+        lr = targetlr + (baselr - targetlr) * (1 - iter / maxiter) ^ power
+
+    For mode='cosine'::
+
+        lr = targetlr + (baselr - targetlr) * (1 + cos(pi * iter / maxiter)) / 2
+
+    If warmup_epochs > 0, a warmup stage will be inserted before the main lr scheduler.
+
+    For warmup_mode='linear'::
+
+        lr = warmup_lr + (baselr - warmup_lr) * iter / max_warmup_iter
+
+    For warmup_mode='constant'::
+
+        lr = warmup_lr
+
+    Parameters
+    ----------
+    mode : str
+        Modes for learning rate scheduler.
+        Currently it supports 'step', 'poly' and 'cosine'.
+    base_lr : float
+        Base learning rate, i.e. the starting learning rate.
+    niters : int
+        Number of iterations in each epoch.
+    nepochs : int
+        Number of training epochs.
+    step : list
+        A list of epochs to decay the learning rate.
+    step_factor : float
+        Learning rate decay factor.
+    target_lr : float
+        Target learning rate for poly and cosine, as the ending learning rate.
+    power : float
+        Power of poly function.
+    warmup_epochs : int
+        Number of epochs for the warmup stage.
+    warmup_lr : float
+        The base learning rate for the warmup stage.
+    warmup_mode : str
+        Modes for the warmup stage.
+        Currently it supports 'linear' and 'constant'.
+    """
+    def __init__(self, mode, base_lr, niters, nepochs,
+                 step=(30, 60, 90), step_factor=0.1, target_lr=0, power=0.9,
+                 warmup_epochs=0, warmup_lr=0, warmup_mode='linear'):
+        super(DistLRScheduler, self).__init__()
+        assert(mode in ['step', 'poly', 'cosine'])
+        assert(warmup_mode in ['linear', 'constant'])
+
+        self.mode = mode
+        self.base_lr = base_lr
+        self.learning_rate = self.base_lr
+        self.niters = niters
+
+        self.step = step
+        self.step_factor = step_factor
+        self.target_lr = target_lr
+        self.power = power
+        self.warmup_epochs = warmup_epochs
+        self.warmup_lr = warmup_lr
+        self.warmup_mode = warmup_mode
+
+        self.N = nepochs * niters
+        self.warmup_N = warmup_epochs * niters
+
+    def __call__(self, num_update):
+        self._update(num_update)
+        return self.learning_rate
+
+    def _update(self, T):
+        epoch = T // self.niters
+        if T > self.N:
+            warnings.warn("DistLRScheduler expects <= %d updates, but got num_update=%d. "
+                          "This might be caused by extra data samples rolling over." % (self.N, T))
+            return
+
+        if self.warmup_epochs > epoch:
+            # Warm-up Stage
+            if self.warmup_mode == 'linear':
+                self.learning_rate = self.warmup_lr + (self.base_lr - self.warmup_lr) * \
+                    T / self.warmup_N
+            elif self.warmup_mode == 'constant':
+                self.learning_rate = self.warmup_lr
+            else:
+                raise NotImplementedError
+        else:
+            if self.mode == 'step':
+                count = sum([1 for s in self.step if s <= epoch])
+                self.learning_rate = self.base_lr * pow(self.step_factor, count)
+            elif self.mode == 'poly':
+                self.learning_rate = self.target_lr + (self.base_lr - self.target_lr) * \
+                    pow(1 - (T - self.warmup_N) / (self.N - self.warmup_N), self.power)
+            elif self.mode == 'cosine':
+                self.learning_rate = self.target_lr + (self.base_lr - self.target_lr) * \
+                    (1 + cos(pi * (T - self.warmup_N) / (self.N - self.warmup_N))) / 2
+            else:
+                raise NotImplementedError
diff --git a/scripts/detection/ssd/dist_train_ssd.py b/scripts/detection/ssd/dist_train_ssd.py
new file mode 100644
index 0000000000..090b187434
--- /dev/null
+++ b/scripts/detection/ssd/dist_train_ssd.py
@@ -0,0 +1,286 @@
+"""Distributed SSD Training"""
+import argparse
+import os
+import logging
+import time
+import numpy as np
+import mxnet as mx
+from mxnet import nd
+from mxnet import gluon
+from mxnet import autograd
+import gluoncv as gcv
+from gluoncv import data as gdata
+from gluoncv import utils as gutils
+from gluoncv.model_zoo import get_model
+from gluoncv.data.batchify import Tuple, Stack, Pad
+from gluoncv.data.transforms.presets.ssd import SSDDefaultTrainTransform
+from gluoncv.data.transforms.presets.ssd import SSDDefaultValTransform
+from gluoncv.data.sampler import SplitSampler
+from gluoncv.utils.metrics.voc_detection import VOC07MApMetric
+from gluoncv.utils.metrics.coco_detection import COCODetectionMetric
+from gluoncv.utils.metrics.accuracy import Accuracy
+from gluoncv.utils.lr_scheduler import DistLRScheduler
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='Train SSD networks.')
+    parser.add_argument('--network', type=str, default='vgg16_atrous',
+                        help="Base network name which serves as feature extraction base.")
+    parser.add_argument('--data-shape', type=int, default=300,
+                        help="Input data shape, use 300 or 512.")
+    parser.add_argument('--batch-size', type=int, default=32,
+                        help='Training mini-batch size')
+    parser.add_argument('--dataset', type=str, default='voc',
+                        help='Training dataset. Now supports voc and coco.')
+    parser.add_argument('--num-workers', '-j', dest='num_workers', type=int,
+                        default=4, help='Number of data workers, you can use larger '
+                        'number to accelerate data loading, if your CPU and GPUs are powerful.')
+    parser.add_argument('--gpus', type=str, default='0',
+                        help='Training with GPUs, you can specify 1,3 for example.')
+    parser.add_argument('--kvstore', type=str, default='device',
+                        help='KVStore type. Supports dist_sync_device, device')
+    parser.add_argument('--epochs', type=int, default=240,
+                        help='Training epochs.')
+    parser.add_argument('--resume', type=str, default='',
+                        help='Resume from previously saved parameters if not None. '
+                        'For example, you can resume from ./ssd_xxx_0123.params')
+    parser.add_argument('--start-epoch', type=int, default=0,
+                        help='Starting epoch for resuming, default is 0 for new training. '
+                        'You can specify it to 100 for example to start from epoch 100.')
+    parser.add_argument('--lr', type=float, default=0.001,
+                        help='Learning rate, default is 0.001')
+    parser.add_argument('--lr-decay', type=float, default=0.1,
+                        help='Decay rate of learning rate. Default is 0.1.')
+    parser.add_argument('--lr-decay-epoch', type=str, default='160,200',
+                        help='Epochs at which learning rate decays. Default is 160,200.')
+    parser.add_argument('--momentum', type=float, default=0.9,
+                        help='SGD momentum, default is 0.9')
+    parser.add_argument('--wd', type=float, default=0.0005,
+                        help='Weight decay, default is 5e-4')
+    parser.add_argument('--log-interval', type=int, default=100,
+                        help='Logging mini-batch interval. Default is 100.')
+    parser.add_argument('--save-prefix', type=str, default='',
+                        help='Saving parameter prefix')
+    parser.add_argument('--save-interval', type=int, default=10,
+                        help='Saving parameters epoch interval, best model will always be saved.')
+    parser.add_argument('--val-interval', type=int, default=1,
+                        help='Epoch interval for validation, increasing the number will reduce the '
+                        'training time if validation is slow.')
+    parser.add_argument('--seed', type=int, default=233,
+                        help='Random seed to be fixed.')
+    args = parser.parse_args()
+    return args
+
+def get_dataset(dataset, args):
+    if dataset.lower() == 'voc':
+        train_dataset = gdata.VOCDetection(
+            splits=[(2007, 'trainval'), (2012, 'trainval')])
+        val_dataset = gdata.VOCDetection(
+            splits=[(2007, 'test')])
+        val_metric = VOC07MApMetric(iou_thresh=0.5, class_names=val_dataset.classes)
+        num_training_samples = 16551
+    elif dataset.lower() == 'coco':
+        train_dataset = gdata.COCODetection(splits='instances_train2017')
+        val_dataset = gdata.COCODetection(splits='instances_val2017', skip_empty=False)
+        val_metric = COCODetectionMetric(
+            val_dataset, args.save_prefix + '_eval', cleanup=True,
+            data_shape=(args.data_shape, args.data_shape))
+        num_training_samples = 117266
+        # coco validation is slow, consider increasing the validation interval
+        if args.val_interval == 1:
+            args.val_interval = 10
+    else:
+        raise NotImplementedError('Dataset: {} not implemented.'.format(dataset))
+    return train_dataset, val_dataset, val_metric, num_training_samples
+
+def get_dataloader(net, train_dataset, val_dataset, data_shape, batch_size,
+                   num_workers, store, num_training_samples):
+    """Get dataloader."""
+    width, height = data_shape, data_shape
+    # use fake data to generate fixed anchors for target generation
+    with autograd.train_mode():
+        _, _, anchors = net(mx.nd.zeros((1, 3, height, width)))
+    batchify_fn = Tuple(Stack(), Stack(), Stack())  # stack image, cls_targets, box_targets
+    # use split sampler to access a subset of the dataset
+    train_loader = gluon.data.DataLoader(
+        train_dataset.transform(SSDDefaultTrainTransform(width, height, anchors)),
+        batch_size, False, batchify_fn=batchify_fn, last_batch='rollover', num_workers=num_workers,
+        sampler=SplitSampler(num_training_samples, num_parts=store.num_workers, part_index=store.rank))
+    val_batchify_fn = Tuple(Stack(), Pad(pad_val=-1))
+    val_loader = gluon.data.DataLoader(
+        val_dataset.transform(SSDDefaultValTransform(width, height)),
+        batch_size, False, batchify_fn=val_batchify_fn, last_batch='keep', num_workers=num_workers)
+    return train_loader, val_loader
+
+def save_params(net, best_map, current_map, epoch, save_interval, prefix):
+    current_map = float(current_map)
+    if current_map > best_map[0]:
+        best_map[0] = current_map
+        net.save_params('{:s}_best.params'.format(prefix))
+        with open(prefix+'_best_map.log', 'a') as f:
+            f.write('{:04d}:\t{:.4f}\n'.format(epoch, current_map))
+    if save_interval and epoch % save_interval == 0:
+        net.save_params('{:s}_{:04d}_{:.4f}.params'.format(prefix, epoch, current_map))
+
+def validate(net, val_data, ctx, eval_metric):
+    """Test on validation dataset."""
+    eval_metric.reset()
+    # set nms threshold and topk constraint
+    net.set_nms(nms_thresh=0.45, nms_topk=400)
+    net.hybridize(static_alloc=True, static_shape=True)
+    for batch in val_data:
+        data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0, even_split=False)
+        label = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0, even_split=False)
+        det_bboxes = []
+        det_ids = []
+        det_scores = []
+        gt_bboxes = []
+        gt_ids = []
+        gt_difficults = []
+        for x, y in zip(data, label):
+            # get prediction results
+            ids, scores, bboxes = net(x)
+            det_ids.append(ids)
+            det_scores.append(scores)
+            # clip to image size
+            det_bboxes.append(bboxes.clip(0, batch[0].shape[2]))
+            # split ground truths
+            gt_ids.append(y.slice_axis(axis=-1, begin=4, end=5))
+            gt_bboxes.append(y.slice_axis(axis=-1, begin=0, end=4))
+            gt_difficults.append(y.slice_axis(axis=-1, begin=5, end=6) if y.shape[-1] > 5 else None)
+
+        # update metric
+        eval_metric.update(det_bboxes, det_ids, det_scores, gt_bboxes, gt_ids, gt_difficults)
+    return eval_metric.get()
+
+def _rescale_grad(params, scale):
+    for param in params.values():
+        if param.grad_req != 'null':
+            for grad in param.list_grad():
+                grad *= scale
+
+def train(net, train_data, val_data, eval_metric, ctx, args, num_training_samples, store):
+    """Training pipeline"""
+    net.collect_params().reset_ctx(ctx)
+    trainer = gluon.Trainer(
+        net.collect_params(), 'sgd',
+        {'learning_rate': args.lr, 'wd': args.wd, 'momentum': args.momentum})
+
+    # lr decay policy
+    lr_decay = float(args.lr_decay)
+    lr_steps = sorted([float(ls) for ls in args.lr_decay_epoch.split(',') if ls.strip()])
+
+    mbox_loss = gcv.loss.HybridSSDMultiBoxLoss()
+    ce_metric = mx.metric.Loss('CrossEntropy')
+    smoothl1_metric = mx.metric.Loss('SmoothL1')
+
+    # set up logger
+    logging.basicConfig()
+    logger = logging.getLogger()
+    logger.setLevel(logging.INFO)
+    log_file_path = args.save_prefix + '_train.log'
+    log_dir = os.path.dirname(log_file_path)
+    if log_dir and not os.path.exists(log_dir):
+        os.makedirs(log_dir)
+    fh = logging.FileHandler(log_file_path)
+    logger.addHandler(fh)
+    logger.info(args)
+    logger.info('Start training from [Epoch {}]'.format(args.start_epoch))
+    best_map = [0]
+    net.collect_params().reset_ctx(ctx)
+
+    # learning rate scheduler for (dist) kvstore
+    lr_scheduler = DistLRScheduler(mode='step', base_lr=args.lr,
+                                   niters=num_training_samples//(args.batch_size*store.num_workers),
+                                   nepochs=args.epochs, step=lr_steps, step_factor=lr_decay, power=2,
+                                   warmup_epochs=0)
+    trainer = gluon.Trainer(
+        net.collect_params(), 'sgd',
+        {'learning_rate': args.lr, 'wd': args.wd, 'momentum': args.momentum, 'lr_scheduler': lr_scheduler},
+        kvstore=store)
+
+    for epoch in range(args.start_epoch, args.epochs):
+        ce_metric.reset()
+        smoothl1_metric.reset()
+        tic = time.time()
+        btic = time.time()
+        net.hybridize(static_alloc=True, static_shape=True)
+        mbox_loss.hybridize(static_alloc=True, static_shape=True)
+        for i, batch in enumerate(train_data):
+            batch_size = batch[0].shape[0]
+            data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0)
+            cls_targets = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0)
+            box_targets = gluon.utils.split_and_load(batch[2], ctx_list=ctx, batch_axis=0)
+
+            sum_losses = []
+            cls_losses = []
+            box_losses = []
+            num_poses = []
+            for x, cls_target, box_target in zip(data, cls_targets, box_targets):
+                with autograd.record():
+                    cls_pred, box_pred, _ = net(x)
+                    sum_loss, cls_loss, box_loss, num_pos = mbox_loss(
+                        cls_pred, box_pred, cls_target, box_target)
+                    sum_losses.append(sum_loss)
+                    cls_losses.append(cls_loss)
+                    box_losses.append(box_loss)
+                    num_poses.append(num_pos)
+            autograd.backward(sum_losses)
+
+            num_positives = sum([pos.asscalar() for pos in num_poses])
+            # Normalize the gradient by the number of positive samples in the batch
+            _rescale_grad(net.collect_params(), 1.0/num_positives)
+            trainer.step(store.num_workers)
+            ce_metric.update(0, [l * batch_size / num_positives for l in cls_losses])
+            smoothl1_metric.update(0, [l * batch_size / num_positives for l in box_losses])
+            if args.log_interval and not (i + 1) % args.log_interval:
+                name1, loss1 = ce_metric.get()
+                name2, loss2 = smoothl1_metric.get()
+                logger.info('[Epoch {}][Batch {}], Speed: {:.3f} samples/sec, {}={:.3f}, {}={:.3f}'.format(
+                    epoch, i, batch_size/(time.time()-btic), name1, loss1, name2, loss2))
+            btic = time.time()
+
+        name1, loss1 = ce_metric.get()
+        name2, loss2 = smoothl1_metric.get()
+        logger.info('[Epoch {}] Training cost: {:.3f}, {}={:.3f}, {}={:.3f}'.format(
+            epoch, (time.time()-tic), name1, loss1, name2, loss2))
+        if (epoch % args.val_interval == 0) or (args.save_interval and epoch % args.save_interval == 0):
+            # consider reducing the frequency of validation to save time
+            map_name, mean_ap = validate(net, val_data, ctx, eval_metric)
+            val_msg = '\n'.join(['{}={}'.format(k, v) for k, v in zip(map_name, mean_ap)])
+            logger.info('[Epoch {}] Validation: \n{}'.format(epoch, val_msg))
+            current_map = float(mean_ap[-1])
+        else:
+            current_map = 0.
+        save_params(net, best_map, current_map, epoch, args.save_interval, args.save_prefix)
+
+if __name__ == '__main__':
+    args = parse_args()
+    # fix seed for mxnet, numpy and python builtin random generator.
+    gutils.random.seed(args.seed)
+
+    # kvstore
+    store = mx.kv.create(args.kvstore)
+
+    # training contexts
+    ctx = [mx.gpu(int(i)) for i in args.gpus.split(',') if i.strip()]
+    ctx = ctx if ctx else [mx.cpu()]
+
+    # network
+    net_name = '_'.join(('ssd', str(args.data_shape), args.network, args.dataset))
+    args.save_prefix += net_name
+    net = get_model(net_name, pretrained_base=True)
+    if args.resume.strip():
+        net.load_parameters(args.resume.strip())
+    else:
+        for param in net.collect_params().values():
+            if param._data is not None:
+                continue
+            param.initialize()
+
+    # training data
+    train_dataset, val_dataset, eval_metric, num_training_samples = get_dataset(args.dataset, args)
+    train_data, val_data = get_dataloader(net, train_dataset, val_dataset,
+                                          args.data_shape, args.batch_size,
+                                          args.num_workers, store, num_training_samples)
+    # training
+    train(net, train_data, val_data, eval_metric, ctx, args, num_training_samples, store)
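
The per-worker data sharding in the script relies on the new SplitSampler together with the kvstore's num_workers and rank. A minimal sketch of that behavior (the two-worker split below is an illustrative assumption, not part of the patch):

    # Each of two hypothetical workers iterates over a disjoint, shuffled half of a 10-sample dataset.
    from gluoncv.data.sampler import SplitSampler

    part0 = SplitSampler(length=10, num_parts=2, part_index=0)
    part1 = SplitSampler(length=10, num_parts=2, part_index=1)
    assert sorted(part0) == list(range(0, 5))
    assert sorted(part1) == list(range(5, 10))

In the training script the same call is made with num_parts=store.num_workers and part_index=store.rank, so each machine iterates only its own shard, and trainer.step(store.num_workers) divides the gradients aggregated by a dist kvstore by the number of workers.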