Skip to content

Commit

Permalink
fix seg training with batchnorm.
Browse files Browse the repository at this point in the history
  • Loading branch information
donnyyou committed Jul 27, 2019
1 parent b47afa0 commit 7dfb954
Show file tree
Hide file tree
Showing 36 changed files with 142 additions and 184 deletions.
27 changes: 18 additions & 9 deletions configs/seg/cityscapes/fs_deeplabv3_cityscapes_seg.conf
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@
"backbone": "deepbase_resnet101_dilated8",
"multi_grid": [1, 1, 1],
"model_name": "deeplabv3",
"norm_type": "sync_batchnorm",
"norm_type": "batchnorm",
"stride": 8,
"checkpoints_name": "fs_deeplabv3_cityscapes_seg",
"checkpoints_dir": "./checkpoints/seg/cityscapes"
Expand Down Expand Up @@ -122,17 +122,26 @@
"max_iters": 40000
},
"loss": {
"loss_type": "seg_auxce_loss",
"loss_type": "dsnce_loss",
"loss_weights": {
"aux_loss": 0.4,
"seg_loss": 1.0
"ce_loss": {
"ce_loss": 1.0
},
"dsnce_loss": {
"ce_loss": 1.0, "dsn_ce_loss": 0.4
},
"dsnohemce_loss": {
"ohem_ce_loss": 1.0, "dsn_ce_loss": 0.4
}
},
"params": {
"ce_weight": [0.8373, 0.9180, 0.8660, 1.0345, 1.0166, 0.9969, 0.9754,
1.0489, 0.8786, 1.0023, 0.9539, 0.9843, 1.1116, 0.9037,
1.0865, 1.0955, 1.0865, 1.1529, 1.0507],
"ce_reduction": "mean",
"ce_ignore_index": -1
"ce_loss": {
"weight": [0.8373, 0.9180, 0.8660, 1.0345, 1.0166, 0.9969, 0.9754,
1.0489, 0.8786, 1.0023, 0.9539, 0.9843, 1.1116, 0.9037,
1.0865, 1.0955, 1.0865, 1.1529, 1.0507],
"reduction": "mean",
"ignore_index": -1
}
}
}
}
10 changes: 7 additions & 3 deletions configs/seg/cityscapes/fs_denseaspp_cityscapes_seg.conf
Original file line number Diff line number Diff line change
Expand Up @@ -122,10 +122,14 @@
"max_iters": 40000
},
"loss": {
"loss_type": "seg_ce_loss",
"loss_type": "ce_loss",
"loss_weights": {
"aux_loss": 0.4,
"seg_loss": 1.0
"ce_loss": {
"ce_loss": 1.0
},
"ohemce_loss": {
"ohem_ce_loss": 1.0
}
},
"params": {
"ce_weight": [0.8373, 0.9180, 0.8660, 1.0345, 1.0166, 0.9969, 0.9754,
Expand Down
27 changes: 18 additions & 9 deletions configs/seg/cityscapes/fs_pspnet_cityscapes_seg.conf
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@
"backbone": "deepbase_resnet101_dilated8",
"multi_grid": [1, 1, 1],
"model_name": "pspnet",
"norm_type": "sync_batchnorm",
"norm_type": "batchnorm",
"stride": 8,
"checkpoints_name": "fs_pspnet_cityscapes_seg",
"checkpoints_dir": "./checkpoints/seg/cityscapes"
Expand Down Expand Up @@ -122,17 +122,26 @@
"max_iters": 40000
},
"loss": {
"loss_type": "seg_auxce_loss",
"loss_type": "dsnce_loss",
"loss_weights": {
"aux_loss": 0.4,
"seg_loss": 1.0
"ce_loss": {
"ce_loss": 1.0
},
"dsnce_loss": {
"ce_loss": 1.0, "dsn_ce_loss": 0.4
},
"dsnohemce_loss": {
"ohem_ce_loss": 1.0, "dsn_ce_loss": 0.4
},
},
"params": {
"ce_weight": [0.8373, 0.9180, 0.8660, 1.0345, 1.0166, 0.9969, 0.9754,
1.0489, 0.8786, 1.0023, 0.9539, 0.9843, 1.1116, 0.9037,
1.0865, 1.0955, 1.0865, 1.1529, 1.0507],
"ce_reduction": "mean",
"ce_ignore_index": -1
"ce_loss": {
"weight": [0.8373, 0.9180, 0.8660, 1.0345, 1.0166, 0.9969, 0.9754,
1.0489, 0.8786, 1.0023, 0.9539, 0.9843, 1.1116, 0.9037,
1.0865, 1.0955, 1.0865, 1.1529, 1.0507],
"reduction": "mean",
"ignore_index": -1
}
}
}
}
2 changes: 1 addition & 1 deletion datasets/cls/loader/default_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import os
import torch.utils.data as data

from exts.tools.parallel import DataContainer
from exts.tools.parallel.data_container import DataContainer
from tools.helper.image_helper import ImageHelper
from tools.util.logger import Logger as Log

Expand Down
2 changes: 1 addition & 1 deletion datasets/det/loader/default_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import numpy as np
import torch.utils.data as data

from exts.tools.parallel import DataContainer
from exts.tools.parallel.data_container import DataContainer
from tools.helper.json_helper import JsonHelper
from tools.helper.image_helper import ImageHelper
from tools.util.logger import Logger as Log
Expand Down
2 changes: 1 addition & 1 deletion datasets/det/loader/fasterrcnn_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import numpy as np
import torch.utils.data as data

from exts.tools.parallel import DataContainer
from exts.tools.parallel.data_container import DataContainer
from tools.helper.json_helper import JsonHelper
from tools.helper.image_helper import ImageHelper
from tools.util.logger import Logger as Log
Expand Down
2 changes: 1 addition & 1 deletion datasets/gan/loader/cyclegan_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import random
import torch.utils.data as data

from exts.tools.parallel import DataContainer
from exts.tools.parallel.data_container import DataContainer
from tools.helper.image_helper import ImageHelper


Expand Down
2 changes: 1 addition & 1 deletion datasets/gan/loader/default_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import os
import torch.utils.data as data

from exts.tools.parallel import DataContainer
from exts.tools.parallel.data_container import DataContainer
from tools.helper.image_helper import ImageHelper
from tools.util.logger import Logger as Log

Expand Down
2 changes: 1 addition & 1 deletion datasets/gan/loader/facegan_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import random
import torch.utils.data as data

from exts.tools.parallel import DataContainer
from exts.tools.parallel.data_container import DataContainer
from tools.helper.image_helper import ImageHelper
from tools.util.logger import Logger as Log

Expand Down
2 changes: 1 addition & 1 deletion datasets/pose/loader/default_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import torch
import torch.utils.data as data

from exts.tools.parallel import DataContainer
from exts.tools.parallel.data_container import DataContainer
from datasets.pose.utils.heatmap_generator import HeatmapGenerator
from tools.helper.json_helper import JsonHelper
from tools.helper.image_helper import ImageHelper
Expand Down
2 changes: 1 addition & 1 deletion datasets/pose/loader/openpose_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import numpy as np
import torch.utils.data as data

from exts.tools.parallel import DataContainer
from exts.tools.parallel.data_container import DataContainer
from datasets.pose.utils.heatmap_generator import HeatmapGenerator
from datasets.pose.utils.paf_generator import PafGenerator
from tools.helper.json_helper import JsonHelper
Expand Down
2 changes: 1 addition & 1 deletion datasets/seg/loader/default_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import numpy as np
from torch.utils import data

from exts.tools.parallel import DataContainer
from exts.tools.parallel.data_container import DataContainer
from tools.helper.image_helper import ImageHelper
from tools.util.logger import Logger as Log

Expand Down
2 changes: 1 addition & 1 deletion datasets/test/loader/default_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import os
import torch.utils.data as data

from exts.tools.parallel import DataContainer
from exts.tools.parallel.data_container import DataContainer
from tools.helper.file_helper import FileHelper
from tools.helper.image_helper import ImageHelper
from tools.util.logger import Logger as Log
Expand Down
2 changes: 1 addition & 1 deletion datasets/test/loader/facegan_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import os
import torch.utils.data as data

from exts.tools.parallel import DataContainer
from exts.tools.parallel.data_container import DataContainer
from tools.helper.json_helper import JsonHelper
from tools.helper.image_helper import ImageHelper
from tools.util.logger import Logger as Log
Expand Down
2 changes: 1 addition & 1 deletion datasets/test/loader/json_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import os
import torch.utils.data as data

from exts.tools.parallel import DataContainer
from exts.tools.parallel.data_container import DataContainer
from tools.helper.json_helper import JsonHelper
from tools.helper.image_helper import ImageHelper
from tools.util.logger import Logger as Log
Expand Down
2 changes: 1 addition & 1 deletion datasets/test/loader/list_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import os
import torch.utils.data as data

from exts.tools.parallel import DataContainer
from exts.tools.parallel.data_container import DataContainer
from tools.helper.image_helper import ImageHelper
from tools.util.logger import Logger as Log

Expand Down
2 changes: 1 addition & 1 deletion datasets/tools/collate.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from torch.utils.data.dataloader import default_collate
from torch._six import string_classes, int_classes

from exts.tools.parallel import DataContainer
from exts.tools.parallel.data_container import DataContainer
from tools.helper.tensor_helper import TensorHelper
from tools.util.logger import Logger as Log

Expand Down
10 changes: 0 additions & 10 deletions exts/tools/parallel/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +0,0 @@
from .data_container import DataContainer
from .data_parallel import DataParallelModel, DataParallelCriterion
from .distributed import MMDistributedDataParallel
from .scatter_gather import scatter, scatter_kwargs

__all__ = [
'DataContainer', 'MMDistributedDataParallel',
'DataParallelModel', 'DataParallelCriterion',
'scatter', 'scatter_kwargs'
]
4 changes: 2 additions & 2 deletions exts/tools/parallel/data_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class ParallelModel(DataParallel):
>>> y = net(x)
"""
def __init__(self, module, device_ids=None, output_device=None, dim=0, gather_=True):
super(DataParallelModel, self).__init__(module, device_ids, output_device, dim)
super(ParallelModel, self).__init__(module, device_ids, output_device, dim)
self.gather_ = gather_

def gather(self, outputs, output_device):
Expand All @@ -63,7 +63,7 @@ class ParallelCriterion(DataParallel):
>>> loss = criterion(y, target)
"""
def __init__(self, module, device_ids=None, output_device=None, dim=0):
super(DataParallelCriterion, self).__init__(module, device_ids, output_device, dim)
super(ParallelCriterion, self).__init__(module, device_ids, output_device, dim)

def scatter(self, inputs, kwargs, device_ids):
return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
Expand Down
9 changes: 5 additions & 4 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,8 @@ def str2bool(v):
dest='network.resume_continue', help='Whether to continue training.')
parser.add_argument('--resume_val', type=str2bool, nargs='?', default=True,
dest='network.resume_val', help='Whether to validate during resume.')
parser.add_argument('--gathered', type=str2bool, nargs='?', default=True,
dest='network.gathered', help='Whether to gather the output of model.')
parser.add_argument('--loss_balance', type=str2bool, nargs='?', default=False,
dest='network.loss_balance', help='Whether to balance GPU usage.')
parser.add_argument('--gather', type=str2bool, nargs='?', default=True,
dest='network.gather', help='Whether to gather the output of model.')

# *********** Params for solver. **********
parser.add_argument('--optim_method', default=None, type=str,
Expand Down Expand Up @@ -149,6 +147,9 @@ def str2bool(v):
if configer.get('network', 'norm_type') is None:
configer.update('network.norm_type', 'batchnorm')

if len(configer.get('gpu')) == 1 or len(range(torch.cuda.device_count())) == 1:
configer.update('network.gather', True)

if configer.get('phase') == 'train':
assert len(configer.get('gpu')) > 1 or 'sync' not in configer.get('network', 'norm_type')

Expand Down
12 changes: 6 additions & 6 deletions model/seg/loss/ce_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@ class CELoss(nn.Module):
def __init__(self, configer=None):
super(CELoss, self).__init__()
self.configer = configer
weight = self.configer.get('loss.params.ce_weight', default=None)
weight = torch.FloatTensor(weight).cuda() if weight is not None else weight
reduction = self.configer.get('loss.params.ce_reduction', default='mean')
ignore_index = self.configer.get('loss.params.ce_ignore_index', default=-100)
self.ce_loss = nn.CrossEntropyLoss(weight=weight, ignore_index=ignore_index, reduction=reduction)
weight = self.configer.get('loss.params.ce_loss.weight', default=None)
self.weight = torch.FloatTensor(weight) if weight is not None else weight
self.reduction = self.configer.get('loss.params.ce_loss.reduction', default='mean')
self.ignore_index = self.configer.get('loss.params.ce_loss.ignore_index', default=-100)

def forward(self, input, target):
    """Compute the cross-entropy loss between logits and label map.

    Args:
        input: logits of shape (N, C, H, W).
        target: label map; rescaled to the spatial size of ``input``
            via ``self._scale_target`` before the loss is evaluated.

    Returns:
        The cross-entropy loss reduced per ``self.reduction``.
    """
    target = self._scale_target(target, (input.size(2), input.size(3)))
    # self.weight may legitimately be None (no class re-weighting);
    # calling .to() on None raises AttributeError, so guard before
    # moving the weights to the device of the incoming logits.
    weight = self.weight.to(input.device) if self.weight is not None else None
    loss = F.cross_entropy(input, target, weight=weight,
                           ignore_index=self.ignore_index, reduction=self.reduction)
    return loss

@staticmethod
Expand Down
14 changes: 4 additions & 10 deletions model/seg/loss/encode_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,11 @@ class EncodeLoss(nn.Module):
def __init__(self, configer):
super(EncodeLoss, self).__init__()
self.configer = configer
weight = None
if self.configer.exists('loss', 'params') and 'enc_weight' in self.configer.get('loss', 'params'):
weight = self.configer.get('loss', 'params')['enc_weight']
weight = torch.FloatTensor(weight).cuda()

reduction = 'mean'
if self.configer.exists('loss', 'params') and 'enc_reduction' in self.configer.get('loss', 'params'):
reduction = self.configer.get('loss', 'params')['enc_reduction']

weight = self.configer.get('loss.params.encode_loss.weight', default=None)
weight = torch.FloatTensor(weight).cuda() if weight is not None else weight
reduction = self.configer.get('loss.params.encode_loss.reduction', default='mean')
self.bce_loss = nn.BCELoss(weight, reduction=reduction)
self.grid_size = self.configer.get('loss', 'params')['enc_grid_size']
self.grid_size = self.configer.get('loss.params.encode_loss.grid_size', default=[1, 1])

def forward(self, preds, targets):
if len(targets.size()) == 2:
Expand Down
26 changes: 9 additions & 17 deletions model/seg/loss/ohem_ce_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,13 @@ class OhemCELoss(nn.Module):
def __init__(self, configer):
    """Online-hard-example-mining CE loss configured from ``loss.params.ohem_ce_loss``.

    Stores the per-class weights, reduction mode, ignore index, the
    probability threshold and the minimum number of kept pixels as
    attributes, since ``forward`` reads them via ``self.``.
    """
    super(OhemCELoss, self).__init__()
    self.configer = configer
    # Optional per-class weights; kept on CPU here and moved to the
    # input's device at forward time.
    weight = self.configer.get('loss.params.ohem_ce_loss.weight', default=None)
    self.weight = torch.FloatTensor(weight) if weight is not None else None
    # These must be attributes, not locals: forward() references
    # self.reduction, self.ignore_index and self.weight.
    self.reduction = self.configer.get('loss.params.ohem_ce_loss.reduction', default='mean')
    self.ignore_index = self.configer.get('loss.params.ohem_ce_loss.ignore_index', default=-100)
    self.thresh = self.configer.get('loss.params.ohem_ce_loss.thresh', default=0.7)
    self.min_kept = max(1, self.configer.get('loss.params.ohem_ce_loss.minkeep', default=1))
    # Retained alias for callers that still reference ignore_label.
    self.ignore_label = self.ignore_index

def forward(self, predict, target):
"""
Expand All @@ -47,7 +37,9 @@ def forward(self, predict, target):
sort_prob, sort_indices = prob.contiguous().view(-1, )[mask].contiguous().sort()
min_threshold = sort_prob[min(self.min_kept, sort_prob.numel() - 1)] if sort_prob.numel() > 0 else 0.0
threshold = max(min_threshold, self.thresh)
loss_matirx = self.ce_loss(predict, target).contiguous().view(-1, )
loss_matrix = F.cross_entropy(predict, target, weight=self.weight.to(input.device),
ignore_index=self.ignore_index, reduction='none')
loss_matirx = loss_matrix.contiguous().view(-1, )
sort_loss_matirx = loss_matirx[mask][sort_indices]
select_loss_matrix = sort_loss_matirx[sort_prob < threshold]
if self.reduction == 'sum' or select_loss_matrix.numel() == 0:
Expand Down
Loading

0 comments on commit 7dfb954

Please sign in to comment.