diff --git a/models/resnet_in.py b/models/resnet_in.py index 31c0c7d..fa2d950 100644 --- a/models/resnet_in.py +++ b/models/resnet_in.py @@ -107,7 +107,7 @@ def forward(self, x): class ResNet(nn.Module): - def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, + def __init__(self, block, layers, initial_kernel_size=7, num_classes=1000, zero_init_residual=False, groups=1, width_per_group=64, replace_stride_with_dilation=None, norm_layer=None): super(ResNet, self).__init__() @@ -116,6 +116,8 @@ def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, self._norm_layer = norm_layer self.inplanes = 64 + self.initial_kernel_size = initial_kernel_size + self.num_classes = num_classes self.dilation = 1 if replace_stride_with_dilation is None: # each element in the tuple indicates if we should replace @@ -126,7 +128,7 @@ def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) self.groups = groups self.base_width = width_per_group - self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, + self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=self.initial_kernel_size, stride=2, padding=3, bias=False) self.bn1 = norm_layer(self.inplanes) self.relu = nn.ReLU(inplace=True) @@ -139,7 +141,7 @@ def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2]) self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) - self.fc = nn.Linear(512 * block.expansion, num_classes) + self.fc = nn.Linear(512 * block.expansion, self.num_classes) for m in self.modules(): if isinstance(m, nn.Conv2d): @@ -209,7 +211,7 @@ def _resnet(arch, block, layers, pretrained, progress, **kwargs): return model -def resnet18(pretrained=False, progress=True, **kwargs): +def ResNet18(pretrained=False, progress=True, **kwargs): r"""ResNet-18 model from `"Deep Residual Learning for Image Recognition" `_ Args: @@ -220,7 +222,7 @@ def resnet18(pretrained=False, progress=True, **kwargs): **kwargs) -def resnet34(pretrained=False, progress=True, **kwargs): +def ResNet34(pretrained=False, progress=True, **kwargs): r"""ResNet-34 model from `"Deep Residual Learning for Image Recognition" `_ Args: @@ -231,7 +233,7 @@ def resnet34(pretrained=False, progress=True, **kwargs): **kwargs) -def resnet50(pretrained=False, progress=True, **kwargs): +def ResNet50(pretrained=False, progress=True, **kwargs): r"""ResNet-50 model from `"Deep Residual Learning for Image Recognition" `_ Args: @@ -242,7 +244,7 @@ def resnet50(pretrained=False, progress=True, **kwargs): **kwargs) -def resnet101(pretrained=False, progress=True, **kwargs): +def ResNet101(pretrained=False, progress=True, **kwargs): r"""ResNet-101 model from `"Deep Residual Learning for Image Recognition" `_ Args: