-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Zhang Yuan
committed
Feb 5, 2018
1 parent
a4ab02a
commit 19b13ef
Showing
6 changed files
with
290 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from .fcn16s import * | ||
from .fcn32s import * | ||
from .fcn8s import * | ||
from .u_net import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
import os | ||
|
||
# Directory that holds every downloaded checkpoint, relative to the working directory.
root = "../pretrained"

# Per-family sub-directories shared by several checkpoints.
_resnet_dir = os.path.join(root, 'ResNet')
_vgg_dir = os.path.join(root, 'VggNet')

# ResNet checkpoints.
res101_path = os.path.join(_resnet_dir, 'resnet101-5d3b4d8f.pth')
res152_path = os.path.join(_resnet_dir, 'resnet152-b121ed2d.pth')

# Inception checkpoint.
inception_v3_path = os.path.join(root, 'Inception', 'inception_v3_google-1a9a5a14.pth')

# VGG checkpoints.
vgg19_bn_path = os.path.join(_vgg_dir, 'vgg19_bn-c79401a0.pth')
vgg16_path = os.path.join(_vgg_dir, 'vgg16-397923af.pth')

# DenseNet checkpoint.
dense201_path = os.path.join(root, 'DenseNet', 'densenet201-4c113574.pth')

# VGG-16 trained using Caffe. The converted weights can be downloaded from
# https://github.com/jcjohnson/pytorch-vgg
vgg16_caffe_path = os.path.join(_vgg_dir, 'vgg16-caffe.pth')
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
import torch | ||
from torch import nn | ||
from torchvision import models | ||
|
||
from ..utils import get_upsampling_weight | ||
from .config import vgg16_caffe_path | ||
|
||
class FCN16VGG(nn.Module):
    """FCN-16s style segmentation network built on a VGG-16 backbone.

    The first two linear layers of the VGG-16 classifier are recast as
    convolutions (``fc6``, ``fc7``) and followed by a 1x1 scoring layer.
    The coarse scores are upsampled 2x, added to scores computed from the
    output of the first 24 feature layers (``pool4`` in ``forward``), then
    upsampled 16x and cropped back to the input's spatial size.

    Args:
        num_classes: number of score maps (output channels) to predict.
        pretrained: when true, VGG-16 weights are loaded from
            ``vgg16_caffe_path`` before its layers are reused.
    """

    def __init__(self, num_classes, pretrained=True):
        super(FCN16VGG, self).__init__()
        vgg = models.vgg16()
        if pretrained:
            vgg.load_state_dict(torch.load(vgg16_caffe_path))

        features = list(vgg.features.children())
        classifier = list(vgg.classifier.children())

        # Pad the first convolution by 100 pixels so that very small inputs
        # are supported and feature maps of different layers can be cropped
        # to matching sizes.
        features[0].padding = (100, 100)

        for f in features:
            if 'MaxPool' in f.__class__.__name__:
                # Round pooled sizes up instead of down.
                f.ceil_mode = True
            elif 'ReLU' in f.__class__.__name__:
                f.inplace = True

        # Backbone split at index 24 so the intermediate map can be scored.
        self.features4 = nn.Sequential(*features[:24])
        self.features5 = nn.Sequential(*features[24:])

        # 1x1 scoring layer for the intermediate features, zero-initialised.
        self.score_pool4 = nn.Conv2d(512, num_classes, kernel_size=1)
        self.score_pool4.weight.data.zero_()
        self.score_pool4.bias.data.zero_()

        # Reuse the classifier's linear weights as convolution kernels.
        fc6 = nn.Conv2d(512, 4096, kernel_size=7)
        fc6.weight.data.copy_(classifier[0].weight.data.view(4096, 512, 7, 7))
        fc6.bias.data.copy_(classifier[0].bias.data)
        fc7 = nn.Conv2d(4096, 4096, kernel_size=1)
        fc7.weight.data.copy_(classifier[3].weight.data.view(4096, 4096, 1, 1))
        fc7.bias.data.copy_(classifier[3].bias.data)
        score_fr = nn.Conv2d(4096, num_classes, kernel_size=1)
        score_fr.weight.data.zero_()
        score_fr.bias.data.zero_()
        # nn.Sequential requires Module *instances*: both dropout layers must
        # be instantiated (passing the bare nn.Dropout class raises TypeError).
        self.score_fr = nn.Sequential(
            fc6, nn.ReLU(inplace=True), nn.Dropout(),
            fc7, nn.ReLU(inplace=True), nn.Dropout(),
            score_fr,
        )

        # Bias-free transposed convolutions whose kernels are filled by the
        # project helper get_upsampling_weight.
        self.upscore2 = nn.ConvTranspose2d(num_classes, num_classes, kernel_size=4, stride=2, bias=False)
        self.upscore16 = nn.ConvTranspose2d(num_classes, num_classes, kernel_size=32, stride=16, bias=False)
        self.upscore2.weight.data.copy_(get_upsampling_weight(num_classes, num_classes, 4))
        self.upscore16.weight.data.copy_(get_upsampling_weight(num_classes, num_classes, 32))

    def forward(self, x):
        """Return per-class score maps with the same height and width as ``x``."""
        x_size = x.size()
        pool4 = self.features4(x)
        pool5 = self.features5(pool4)

        score_fr = self.score_fr(pool5)
        upscore2 = self.upscore2(score_fr)

        # Intermediate features are scaled by 0.01 before scoring, then the
        # score map is cropped (offset 5) to the size of the 2x upsampled
        # coarse prediction so the two can be summed.
        score_pool4 = self.score_pool4(0.01 * pool4)
        upscore16 = self.upscore16(
            score_pool4[:, :, 5: (5 + upscore2.size()[2]), 5: (5 + upscore2.size()[3])] + upscore2
        )

        # Crop (offset 27) back to the input's spatial size.
        return upscore16[:, :, 27: (27 + x_size[2]), 27: (27 + x_size[3])].contiguous()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
import torch | ||
from torch import nn | ||
from torchvision import models | ||
|
||
from ..utils import get_upsampling_weight | ||
from .config import vgg16_caffe_path | ||
|
||
class FCN32VGG(nn.Module):
    """FCN-32s style segmentation network built on a VGG-16 backbone.

    The whole VGG-16 feature stack is followed by the classifier's first two
    linear layers recast as convolutions (``fc6``, ``fc7``) and a 1x1 scoring
    layer. The scores are upsampled 32x in a single step and cropped back to
    the input's spatial size.

    Args:
        num_classes: number of score maps (output channels) to predict.
        pretrained: when true, VGG-16 weights are loaded from
            ``vgg16_caffe_path`` before its layers are reused.
    """

    def __init__(self, num_classes, pretrained=True):
        super(FCN32VGG, self).__init__()
        vgg = models.vgg16()
        if pretrained:
            vgg.load_state_dict(torch.load(vgg16_caffe_path))

        features = list(vgg.features.children())
        classifier = list(vgg.classifier.children())

        # Pad the first convolution by 100 pixels so that very small inputs
        # are supported and the upsampled output can be cropped to size.
        features[0].padding = (100, 100)

        for f in features:
            if 'MaxPool' in f.__class__.__name__:
                # Round pooled sizes up instead of down.
                f.ceil_mode = True
            elif 'ReLU' in f.__class__.__name__:
                f.inplace = True

        self.features5 = nn.Sequential(*features)

        # Reuse the classifier's linear weights as convolution kernels.
        fc6 = nn.Conv2d(512, 4096, kernel_size=7)
        fc6.weight.data.copy_(classifier[0].weight.data.view(4096, 512, 7, 7))
        fc6.bias.data.copy_(classifier[0].bias.data)
        fc7 = nn.Conv2d(4096, 4096, kernel_size=1)
        fc7.weight.data.copy_(classifier[3].weight.data.view(4096, 4096, 1, 1))
        fc7.bias.data.copy_(classifier[3].bias.data)
        score_fr = nn.Conv2d(4096, num_classes, kernel_size=1)
        score_fr.weight.data.zero_()
        score_fr.bias.data.zero_()
        # nn.Sequential requires Module *instances*: both dropout layers must
        # be instantiated (passing the bare nn.Dropout class raises TypeError).
        self.score_fr = nn.Sequential(
            fc6, nn.ReLU(inplace=True), nn.Dropout(),
            fc7, nn.ReLU(inplace=True), nn.Dropout(),
            score_fr,
        )

        # Bias-free transposed convolution whose kernel is filled by the
        # project helper get_upsampling_weight.
        self.upscore = nn.ConvTranspose2d(num_classes, num_classes, kernel_size=64, stride=32, bias=False)
        self.upscore.weight.data.copy_(get_upsampling_weight(num_classes, num_classes, 64))

    def forward(self, x):
        """Return per-class score maps with the same height and width as ``x``."""
        x_size = x.size()
        pool5 = self.features5(x)
        score_fr = self.score_fr(pool5)
        upscore = self.upscore(score_fr)
        # Crop (offset 19) back to the input's spatial size.
        return upscore[:, :, 19: (19 + x_size[2]), 19: (19 + x_size[3])].contiguous()
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
import torch | ||
from torch import nn | ||
from torchvision import models | ||
|
||
from ..utils import get_upsampling_weight | ||
from .config import vgg16_path, vgg16_caffe_path | ||
|
||
class FCN8s(nn.Module):
    """FCN-8s style segmentation network built on a VGG-16 backbone.

    Scores from three depths of the backbone are fused: the coarse prediction
    is upsampled 2x and added to scores of the second feature stage, that sum
    is upsampled 2x again and added to scores of the first feature stage, and
    the result is upsampled 8x and cropped to the input's spatial size.

    Args:
        num_classes: number of score maps (output channels) to predict.
        pretrained: when true, VGG-16 weights are loaded from disk.
        caffe: selects ``vgg16_caffe_path`` instead of ``vgg16_path`` as the
            checkpoint; only consulted when ``pretrained`` is true.
    """

    def __init__(self, num_classes, pretrained=True, caffe=False):
        super(FCN8s, self).__init__()
        backbone = models.vgg16()
        if pretrained:
            # The Caffe checkpoint is the pretrained VGG-16 used by the
            # paper's author.
            checkpoint = vgg16_caffe_path if caffe else vgg16_path
            backbone.load_state_dict(torch.load(checkpoint))

        features = list(backbone.features.children())
        classifier = list(backbone.classifier.children())

        # 100 padding for 2 reasons:
        #   1. support very small input size
        #   2. allow cropping in order to match size of different layers'
        #      feature maps
        features[0].padding = (100, 100)

        for layer in features:
            layer_kind = layer.__class__.__name__
            if 'MaxPool' in layer_kind:
                layer.ceil_mode = True
            elif 'ReLU' in layer_kind:
                layer.inplace = True

        # Backbone split into three consecutive stages.
        self.features3 = nn.Sequential(*features[:17])
        self.features4 = nn.Sequential(*features[17:24])
        self.features5 = nn.Sequential(*features[24:])

        # Zero-initialised 1x1 scoring layers for the two earlier stages.
        self.score_pool3 = nn.Conv2d(256, num_classes, kernel_size=1)
        self.score_pool4 = nn.Conv2d(512, num_classes, kernel_size=1)
        for head in (self.score_pool3, self.score_pool4):
            head.weight.data.zero_()
            head.bias.data.zero_()

        # Classifier linear layers recast as convolutions.
        fc6 = nn.Conv2d(512, 4096, kernel_size=7)
        fc6.weight.data.copy_(classifier[0].weight.data.view(4096, 512, 7, 7))
        fc6.bias.data.copy_(classifier[0].bias.data)

        fc7 = nn.Conv2d(4096, 4096, kernel_size=1)
        fc7.weight.data.copy_(classifier[3].weight.data.view(4096, 4096, 1, 1))
        fc7.bias.data.copy_(classifier[3].bias.data)

        score_fr = nn.Conv2d(4096, num_classes, kernel_size=1)
        score_fr.weight.data.zero_()
        score_fr.bias.data.zero_()

        self.score_fr = nn.Sequential(
            fc6,
            nn.ReLU(inplace=True),
            nn.Dropout(),
            fc7,
            nn.ReLU(inplace=True),
            nn.Dropout(),
            score_fr,
        )

        # Bias-free transposed convolutions; kernels come from the project
        # helper get_upsampling_weight.
        self.upscore2 = nn.ConvTranspose2d(num_classes, num_classes, kernel_size=4, stride=2, bias=False)
        self.upscore_pool4 = nn.ConvTranspose2d(num_classes, num_classes, kernel_size=4, stride=2, bias=False)
        self.upscore8 = nn.ConvTranspose2d(num_classes, num_classes, kernel_size=16, stride=8, bias=False)
        self.upscore2.weight.data.copy_(get_upsampling_weight(num_classes, num_classes, 4))
        self.upscore_pool4.weight.data.copy_(get_upsampling_weight(num_classes, num_classes, 4))
        self.upscore8.weight.data.copy_(get_upsampling_weight(num_classes, num_classes, 16))

    def forward(self, x):
        """Return per-class score maps with the same height and width as ``x``."""
        in_size = x.size()

        pool3 = self.features3(x)
        pool4 = self.features4(pool3)
        pool5 = self.features5(pool4)

        # Coarse prediction, upsampled 2x.
        coarse = self.upscore2(self.score_fr(pool5))

        # Second-stage scores (features scaled by 0.01), cropped at offset 5
        # to the coarse map's size before the sum is upsampled 2x.
        scores4 = self.score_pool4(0.01 * pool4)
        rows, cols = coarse.size()[2], coarse.size()[3]
        fused4 = self.upscore_pool4(scores4[:, :, 5: (5 + rows), 5: (5 + cols)] + coarse)

        # First-stage scores (features scaled by 0.0001), cropped at offset 9
        # to the fused map's size before the sum is upsampled 8x.
        scores3 = self.score_pool3(0.0001 * pool3)
        rows, cols = fused4.size()[2], fused4.size()[3]
        full = self.upscore8(scores3[:, :, 9: (9 + rows), 9: (9 + cols)] + fused4)

        # Crop (offset 31) back to the input's spatial size.
        return full[:, :, 31: (31 + in_size[2]), 31: (31 + in_size[3])].contiguous()
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
import torch | ||
import torch.nn.functional as F | ||
from torch import nn | ||
|
||
from ..utils import initialize_weights | ||
|
||
class _EncoderBlock(nn.Module):
    """Encoder stage: two unpadded 3x3 conv + batch-norm + ReLU groups, then 2x2 max-pooling.

    Args:
        in_channels: channel count of the incoming feature map.
        out_channels: channel count produced by both convolutions.
        dropout: when true, an ``nn.Dropout()`` layer is placed just before
            the pooling layer.
    """

    def __init__(self, in_channels, out_channels, dropout=False):
        super(_EncoderBlock, self).__init__()
        stages = []
        # The first group maps in_channels -> out_channels; the second keeps
        # the width at out_channels.
        for conv_in in (in_channels, out_channels):
            stages += [
                nn.Conv2d(conv_in, out_channels, kernel_size=3),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
        if dropout:
            stages += [nn.Dropout()]
        stages += [nn.MaxPool2d(kernel_size=2, stride=2)]
        self.encode = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the stage and return the pooled feature map."""
        return self.encode(x)
|
||
class _DecoderBlock(nn.Module):
    """Decoder stage: two unpadded 3x3 conv + batch-norm + ReLU groups, then a 2x transposed convolution.

    Args:
        in_channels: channel count of the incoming feature map.
        middle_channels: channel count produced by both 3x3 convolutions.
        out_channels: channel count produced by the transposed convolution.
    """

    def __init__(self, in_channels, middle_channels, out_channels):
        super(_DecoderBlock, self).__init__()
        stages = []
        # The first group maps in_channels -> middle_channels; the second
        # keeps the width at middle_channels.
        for conv_in in (in_channels, middle_channels):
            stages.extend([
                nn.Conv2d(conv_in, middle_channels, kernel_size=3),
                nn.BatchNorm2d(middle_channels),
                nn.ReLU(inplace=True),
            ])
        # Kernel 2 with stride 2 doubles the spatial size.
        stages.append(nn.ConvTranspose2d(middle_channels, out_channels, kernel_size=2, stride=2))
        self.decode = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the stage and return the upsampled feature map."""
        return self.decode(x)
|
||
class UNet(nn.Module):
    """U-Net style encoder/decoder segmentation network for 3-channel input.

    Four encoder stages feed a centre stage; each decoder stage then consumes
    the previous decoder output concatenated with the matching encoder output,
    bilinearly resized to the decoder output's spatial size. A final 1x1
    convolution produces the class scores, which are resized to the input's
    spatial size.

    Args:
        num_classes: number of score maps (output channels) to predict.
    """

    def __init__(self, num_classes):
        super(UNet, self).__init__()

        # Contracting path; only the deepest stage uses dropout.
        self.enc1 = _EncoderBlock(3, 64)
        self.enc2 = _EncoderBlock(64, 128)
        self.enc3 = _EncoderBlock(128, 256)
        self.enc4 = _EncoderBlock(256, 512, dropout=True)

        self.center = _DecoderBlock(512, 1024, 512)

        # Expanding path; input widths are doubled by the skip concatenation.
        self.dec4 = _DecoderBlock(1024, 512, 256)
        self.dec3 = _DecoderBlock(512, 256, 128)
        self.dec2 = _DecoderBlock(256, 128, 64)

        # Last decoder stage has no transposed convolution.
        last_stage = [
            nn.Conv2d(128, 64, kernel_size=3),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        ]
        self.dec1 = nn.Sequential(*last_stage)

        self.final = nn.Conv2d(64, num_classes, kernel_size=1)
        initialize_weights(self)

    def forward(self, x):
        """Return per-class score maps with the same height and width as ``x``."""
        # Encoder outputs are kept so the decoders can consume them in
        # reverse order (deepest first).
        skips = []
        out = x
        for encoder in (self.enc1, self.enc2, self.enc3, self.enc4):
            out = encoder(out)
            skips.append(out)

        out = self.center(out)

        for decoder in (self.dec4, self.dec3, self.dec2, self.dec1):
            # Resize the encoder output to the current decoder map before
            # concatenating along the channel axis.
            skip = F.upsample(skips.pop(), out.size()[2:], mode='bilinear')
            out = decoder(torch.cat([out, skip], 1))

        scores = self.final(out)
        return F.upsample(scores, x.size()[2:], mode='bilinear')