Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update resnet_in.py #5

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions models/resnet_in.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def forward(self, x):

class ResNet(nn.Module):

def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
def __init__(self, block, layers, initial_kernel_size=7, num_classes=1000, zero_init_residual=False,
groups=1, width_per_group=64, replace_stride_with_dilation=None,
norm_layer=None):
super(ResNet, self).__init__()
Expand All @@ -116,6 +116,8 @@ def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
self._norm_layer = norm_layer

self.inplanes = 64
self.initial_kernel_size = initial_kernel_size
self.num_classes = num_classes
self.dilation = 1
if replace_stride_with_dilation is None:
# each element in the tuple indicates if we should replace
Expand All @@ -126,7 +128,7 @@ def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
"or a 3-element tuple, got {}".format(replace_stride_with_dilation))
self.groups = groups
self.base_width = width_per_group
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=self.initial_kernel_size, stride=2, padding=3,
bias=False)
self.bn1 = norm_layer(self.inplanes)
self.relu = nn.ReLU(inplace=True)
Expand All @@ -139,7 +141,7 @@ def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
dilate=replace_stride_with_dilation[2])
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(512 * block.expansion, num_classes)
self.fc = nn.Linear(512 * block.expansion, self.num_classes)

for m in self.modules():
if isinstance(m, nn.Conv2d):
Expand Down Expand Up @@ -209,7 +211,7 @@ def _resnet(arch, block, layers, pretrained, progress, **kwargs):
return model


def resnet18(pretrained=False, progress=True, **kwargs):
def ResNet18(pretrained=False, progress=True, **kwargs):
r"""ResNet-18 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Args:
Expand All @@ -220,7 +222,7 @@ def resnet18(pretrained=False, progress=True, **kwargs):
**kwargs)


def resnet34(pretrained=False, progress=True, **kwargs):
def ResNet34(pretrained=False, progress=True, **kwargs):
r"""ResNet-34 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Args:
Expand All @@ -231,7 +233,7 @@ def resnet34(pretrained=False, progress=True, **kwargs):
**kwargs)


def resnet50(pretrained=False, progress=True, **kwargs):
def ResNet50(pretrained=False, progress=True, **kwargs):
r"""ResNet-50 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Args:
Expand All @@ -242,7 +244,7 @@ def resnet50(pretrained=False, progress=True, **kwargs):
**kwargs)


def resnet101(pretrained=False, progress=True, **kwargs):
def ResNet101(pretrained=False, progress=True, **kwargs):
r"""ResNet-101 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Args:
Expand Down