From 1683bbac652f0cdbba38b3fd75c0c7071865ad0f Mon Sep 17 00:00:00 2001
From: Jakub Kaczmarzyk
Date: Sun, 15 Jan 2023 13:19:09 -0500
Subject: [PATCH] FIX: use inceptionv4 from timm (#69)

* use inceptionv4 from timm

This changes the model name from inceptionv4 to inception_v4 and allows us
to remove the inceptionv4 implementation from the codebase.

* use inception_v4nobn with help of timm

Use the timm implementation but remove batchnorm, and add a bias term to
each conv layer that previously preceded batchnorm. This renames the model
from inceptionv4nobn to inception_v4nobn.

* rm unused default_cfgs

* ignore timm types
---
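Patch note: a minimal sketch of the user-visible effect of this change, not
part of the commit. It assumes timm is installed and wsinfer is importable;
the model names and module path come from the diffs below.

import timm

# Importing this module registers "inception_v4nobn" with timm as a side
# effect (see the @register_model decorator in the diff below).
from wsinfer._modellib import inceptionv4_no_batchnorm  # noqa: F401

# "inception_v4" is timm's stock model; "inception_v4nobn" is the local
# batchnorm-free variant registered above. Both now resolve through timm.
stock = timm.create_model("inception_v4", num_classes=2)
no_bn = timm.create_model("inception_v4nobn", num_classes=2)
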
 setup.cfg                                     |   2 +-
 tests/test_all.py                             |   6 +-
 wsinfer/_modellib/inceptionv4.py              | 306 ------
 wsinfer/_modellib/inceptionv4_no_batchnorm.py | 190 +++++++----
 wsinfer/_modellib/models.py                   |   6 +-
 .../modeldefs/inceptionv4_tcga-brca-v1.yaml   |   2 +-
 .../inceptionv4nobn_tcga-tils-v1.yaml         |   2 +-
 7 files changed, 127 insertions(+), 387 deletions(-)
 delete mode 100644 wsinfer/_modellib/inceptionv4.py

diff --git a/setup.cfg b/setup.cfg
index 1c26af8..f55b906 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -89,7 +89,7 @@ ignore_missing_imports = True
 ignore_missing_imports = True
 [mypy-pandas]
 ignore_missing_imports = True
-[mypy-timm]
+[mypy-timm.*]
 ignore_missing_imports = True
 [mypy-scipy.stats]
 ignore_missing_imports = True
diff --git a/tests/test_all.py b/tests/test_all.py
index ef5637f..1975cfd 100644
--- a/tests/test_all.py
+++ b/tests/test_all.py
@@ -216,9 +216,9 @@ def test_cli_run_args(tmp_path: Path):
             350,
             144,
         ),
-        # Inceptionv4 TCGA-BRCA-v1
+        # Inception_v4 TCGA-BRCA-v1
         (
-            "inceptionv4",
+            "inception_v4",
             "TCGA-BRCA-v1",
             ["notumor", "tumor"],
             [0.9564113020896912, 0.043588679283857346],
@@ -227,7 +227,7 @@
         ),
         # Inceptionv4nobn TCGA-TILs-v1
         (
-            "inceptionv4nobn",
+            "inception_v4nobn",
             "TCGA-TILs-v1",
             ["notils", "tils"],
             [1.0, 3.427359524660334e-12],
diff --git a/wsinfer/_modellib/inceptionv4.py b/wsinfer/_modellib/inceptionv4.py
deleted file mode 100644
index b7f42a3..0000000
--- a/wsinfer/_modellib/inceptionv4.py
+++ /dev/null
@@ -1,306 +0,0 @@
-# BSD 3-Clause License
-#
-# Copyright (c) 2017, Remi Cadene
-# All rights reserved.
-#
-# Downloaded from
-# https://raw.githubusercontent.com/Cadene/pretrained-models.pytorch/e07fb68c317880e780eb5ca9c20cca00f2584878/pretrainedmodels/models/inceptionv4.py # noqa: E501
-#
-# We downloaded this file here so we did not have to add pretrainedmodels as a
-# dependency (we only use this module).
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-
-class BasicConv2d(nn.Module):
-    def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0):
-        super(BasicConv2d, self).__init__()
-        self.conv = nn.Conv2d(
-            in_planes,
-            out_planes,
-            kernel_size=kernel_size,
-            stride=stride,
-            padding=padding,
-            bias=False,
-        )  # verify bias false
-        self.bn = nn.BatchNorm2d(
-            out_planes,
-            eps=0.001,  # value found in tensorflow
-            momentum=0.1,  # default pytorch value
-            affine=True,
-        )
-        self.relu = nn.ReLU(inplace=True)
-
-    def forward(self, x):
-        x = self.conv(x)
-        x = self.bn(x)
-        x = self.relu(x)
-        return x
-
-
-class Mixed_3a(nn.Module):
-    def __init__(self):
-        super(Mixed_3a, self).__init__()
-        self.maxpool = nn.MaxPool2d(3, stride=2)
-        self.conv = BasicConv2d(64, 96, kernel_size=3, stride=2)
-
-    def forward(self, x):
-        x0 = self.maxpool(x)
-        x1 = self.conv(x)
-        out = torch.cat((x0, x1), 1)
-        return out
-
-
-class Mixed_4a(nn.Module):
-    def __init__(self):
-        super(Mixed_4a, self).__init__()
-
-        self.branch0 = nn.Sequential(
-            BasicConv2d(160, 64, kernel_size=1, stride=1),
-            BasicConv2d(64, 96, kernel_size=3, stride=1),
-        )
-
-        self.branch1 = nn.Sequential(
-            BasicConv2d(160, 64, kernel_size=1, stride=1),
-            BasicConv2d(64, 64, kernel_size=(1, 7), stride=1, padding=(0, 3)),
-            BasicConv2d(64, 64, kernel_size=(7, 1), stride=1, padding=(3, 0)),
-            BasicConv2d(64, 96, kernel_size=(3, 3), stride=1),
-        )
-
-    def forward(self, x):
-        x0 = self.branch0(x)
-        x1 = self.branch1(x)
-        out = torch.cat((x0, x1), 1)
-        return out
-
-
-class Mixed_5a(nn.Module):
-    def __init__(self):
-        super(Mixed_5a, self).__init__()
-        self.conv = BasicConv2d(192, 192, kernel_size=3, stride=2)
-        self.maxpool = nn.MaxPool2d(3, stride=2)
-
-    def forward(self, x):
-        x0 = self.conv(x)
-        x1 = self.maxpool(x)
-        out = torch.cat((x0, x1), 1)
-        return out
-
-
-class Inception_A(nn.Module):
-    def __init__(self):
-        super(Inception_A, self).__init__()
-        self.branch0 = BasicConv2d(384, 96, kernel_size=1, stride=1)
-
-        self.branch1 = nn.Sequential(
-            BasicConv2d(384, 64, kernel_size=1, stride=1),
-            BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1),
-        )
-
-        self.branch2 = nn.Sequential(
-            BasicConv2d(384, 64, kernel_size=1, stride=1),
-            BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1),
-            BasicConv2d(96, 96, kernel_size=3, stride=1, padding=1),
-        )
-
-        self.branch3 = nn.Sequential(
-            nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),
-            BasicConv2d(384, 96, kernel_size=1, stride=1),
-        )
-
-    def forward(self, x):
-        x0 = self.branch0(x)
-        x1 = self.branch1(x)
-        x2 = self.branch2(x)
-        x3 = self.branch3(x)
-        out = torch.cat((x0, x1, x2, x3), 1)
-        return out
-
-
-class Reduction_A(nn.Module):
-    def __init__(self):
-        super(Reduction_A, self).__init__()
-        self.branch0 = BasicConv2d(384, 384, kernel_size=3, stride=2)
-
-        self.branch1 = nn.Sequential(
-            BasicConv2d(384, 192, kernel_size=1, stride=1),
-            BasicConv2d(192, 224, kernel_size=3, stride=1, padding=1),
-            BasicConv2d(224, 256, kernel_size=3, stride=2),
-        )
-
-        self.branch2 = nn.MaxPool2d(3, stride=2)
-
-    def forward(self, x):
-        x0 = self.branch0(x)
-        x1 = self.branch1(x)
-        x2 = self.branch2(x)
-        out = torch.cat((x0, x1, x2), 1)
-        return out
-
-
-class Inception_B(nn.Module):
-    def __init__(self):
-        super(Inception_B, self).__init__()
-        self.branch0 = BasicConv2d(1024, 384, kernel_size=1, stride=1)
-
-        self.branch1 = nn.Sequential(
-            BasicConv2d(1024, 192, kernel_size=1, stride=1),
-            BasicConv2d(192, 224, kernel_size=(1, 7), stride=1, padding=(0, 3)),
-            BasicConv2d(224, 256, kernel_size=(7, 1), stride=1, padding=(3, 0)),
-        )
-
-        self.branch2 = nn.Sequential(
-            BasicConv2d(1024, 192, kernel_size=1, stride=1),
-            BasicConv2d(192, 192, kernel_size=(7, 1), stride=1, padding=(3, 0)),
-            BasicConv2d(192, 224, kernel_size=(1, 7), stride=1, padding=(0, 3)),
-            BasicConv2d(224, 224, kernel_size=(7, 1), stride=1, padding=(3, 0)),
-            BasicConv2d(224, 256, kernel_size=(1, 7), stride=1, padding=(0, 3)),
-        )
-
-        self.branch3 = nn.Sequential(
-            nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),
-            BasicConv2d(1024, 128, kernel_size=1, stride=1),
-        )
-
-    def forward(self, x):
-        x0 = self.branch0(x)
-        x1 = self.branch1(x)
-        x2 = self.branch2(x)
-        x3 = self.branch3(x)
-        out = torch.cat((x0, x1, x2, x3), 1)
-        return out
-
-
-class Reduction_B(nn.Module):
-    def __init__(self):
-        super(Reduction_B, self).__init__()
-
-        self.branch0 = nn.Sequential(
-            BasicConv2d(1024, 192, kernel_size=1, stride=1),
-            BasicConv2d(192, 192, kernel_size=3, stride=2),
-        )
-
-        self.branch1 = nn.Sequential(
-            BasicConv2d(1024, 256, kernel_size=1, stride=1),
-            BasicConv2d(256, 256, kernel_size=(1, 7), stride=1, padding=(0, 3)),
-            BasicConv2d(256, 320, kernel_size=(7, 1), stride=1, padding=(3, 0)),
-            BasicConv2d(320, 320, kernel_size=3, stride=2),
-        )
-
-        self.branch2 = nn.MaxPool2d(3, stride=2)
-
-    def forward(self, x):
-        x0 = self.branch0(x)
-        x1 = self.branch1(x)
-        x2 = self.branch2(x)
-        out = torch.cat((x0, x1, x2), 1)
-        return out
-
-
-class Inception_C(nn.Module):
-    def __init__(self):
-        super(Inception_C, self).__init__()
-
-        self.branch0 = BasicConv2d(1536, 256, kernel_size=1, stride=1)
-
-        self.branch1_0 = BasicConv2d(1536, 384, kernel_size=1, stride=1)
-        self.branch1_1a = BasicConv2d(
-            384, 256, kernel_size=(1, 3), stride=1, padding=(0, 1)
-        )
-        self.branch1_1b = BasicConv2d(
-            384, 256, kernel_size=(3, 1), stride=1, padding=(1, 0)
-        )
-
-        self.branch2_0 = BasicConv2d(1536, 384, kernel_size=1, stride=1)
-        self.branch2_1 = BasicConv2d(
-            384, 448, kernel_size=(3, 1), stride=1, padding=(1, 0)
-        )
-        self.branch2_2 = BasicConv2d(
-            448, 512, kernel_size=(1, 3), stride=1, padding=(0, 1)
-        )
-        self.branch2_3a = BasicConv2d(
-            512, 256, kernel_size=(1, 3), stride=1, padding=(0, 1)
-        )
-        self.branch2_3b = BasicConv2d(
-            512, 256, kernel_size=(3, 1), stride=1, padding=(1, 0)
-        )
-
-        self.branch3 = nn.Sequential(
-            nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),
-            BasicConv2d(1536, 256, kernel_size=1, stride=1),
-        )
-
-    def forward(self, x):
-        x0 = self.branch0(x)
-
-        x1_0 = self.branch1_0(x)
-        x1_1a = self.branch1_1a(x1_0)
-        x1_1b = self.branch1_1b(x1_0)
-        x1 = torch.cat((x1_1a, x1_1b), 1)
-
-        x2_0 = self.branch2_0(x)
-        x2_1 = self.branch2_1(x2_0)
-        x2_2 = self.branch2_2(x2_1)
-        x2_3a = self.branch2_3a(x2_2)
-        x2_3b = self.branch2_3b(x2_2)
-        x2 = torch.cat((x2_3a, x2_3b), 1)
-
-        x3 = self.branch3(x)
-
-        out = torch.cat((x0, x1, x2, x3), 1)
-        return out
-
-
-class InceptionV4(nn.Module):
-    def __init__(self, num_classes=1001):
-        super(InceptionV4, self).__init__()
-        # Special attributs
-        self.input_space = None
-        self.input_size = (299, 299, 3)
-        self.mean = None
-        self.std = None
-        # Modules
-        self.features = nn.Sequential(
-            BasicConv2d(3, 32, kernel_size=3, stride=2),
-            BasicConv2d(32, 32, kernel_size=3, stride=1),
-            BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1),
-            Mixed_3a(),
-            Mixed_4a(),
-            Mixed_5a(),
-            Inception_A(),
-            Inception_A(),
-            Inception_A(),
-            Inception_A(),
-            Reduction_A(),  # Mixed_6a
-            Inception_B(),
-            Inception_B(),
-            Inception_B(),
-            Inception_B(),
-            Inception_B(),
-            Inception_B(),
-            Inception_B(),
-            Reduction_B(),  # Mixed_7a
-            Inception_C(),
-            Inception_C(),
-            Inception_C(),
-        )
-        self.last_linear = nn.Linear(1536, num_classes)
-
-    def logits(self, features):
-        # Allows image of any size to be processed
-        adaptiveAvgPoolWidth = features.shape[2]
-        x = F.avg_pool2d(features, kernel_size=adaptiveAvgPoolWidth)
-        x = x.view(x.size(0), -1)
-        x = self.last_linear(x)
-        return x
-
-    def forward(self, input):
-        x = self.features(input)
-        x = self.logits(x)
-        return x
-
-
-def inceptionv4(num_classes: int):
-    return InceptionV4(num_classes=num_classes)
diff --git a/wsinfer/_modellib/inceptionv4_no_batchnorm.py b/wsinfer/_modellib/inceptionv4_no_batchnorm.py
index dc4f934..2fb0bdc 100644
--- a/wsinfer/_modellib/inceptionv4_no_batchnorm.py
+++ b/wsinfer/_modellib/inceptionv4_no_batchnorm.py
@@ -1,21 +1,22 @@
-# BSD 3-Clause License
-#
-# Copyright (c) 2017, Remi Cadene
-# All rights reserved.
-#
-# Downloaded from
-# https://raw.githubusercontent.com/Cadene/pretrained-models.pytorch/e07fb68c317880e780eb5ca9c20cca00f2584878/pretrainedmodels/models/inceptionv4.py # noqa: E501
-#
-# We downloaded this file here so we did not have to add pretrainedmodels as a
-# dependency (we only use this module).
-#
-# Modified to not use batchnorm. Models trained with TF Slim do not use batchnorm.
+# https://raw.githubusercontent.com/rwightman/pytorch-image-models/e9aac412de82310e6905992e802b1ee4dc52b5d1/timm/models/inception_v4.py
+"""
+Pytorch Inception-V4 implementation
+Sourced from https://github.com/Cadene/tensorflow-model-zoo.torch (MIT License) which is
+based upon Google's Tensorflow implementation and pretrained weights
+(Apache 2.0 License).
+
+This source was copied into the wsinfer source code and modified to remove batchnorm.
+Bias terms are added wherever batchnorm is removed.
+"""
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-__all__ = ["InceptionV4", "inceptionv4"]
+# from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
+from timm.models import register_model
+from timm.models.helpers import build_model_with_cfg
+from timm.models.layers import create_classifier
 
 
 class BasicConv2d(nn.Module):
@@ -27,19 +28,21 @@ def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0):
             kernel_size=kernel_size,
             stride=stride,
             padding=padding,
-            bias=True,  # Changed this to True after removing batchnorm.
+            bias=True,  # Set to True after removing BatchNorm.
         )
+        # self.bn = nn.BatchNorm2d(out_planes, eps=0.001)
         self.relu = nn.ReLU(inplace=True)
 
     def forward(self, x):
        x = self.conv(x)
+        # x = self.bn(x)
         x = self.relu(x)
         return x
 
 
-class Mixed_3a(nn.Module):
+class Mixed3a(nn.Module):
     def __init__(self):
-        super(Mixed_3a, self).__init__()
+        super(Mixed3a, self).__init__()
         self.maxpool = nn.MaxPool2d(3, stride=2)
         self.conv = BasicConv2d(64, 96, kernel_size=3, stride=2)
 
@@ -50,9 +53,9 @@ def forward(self, x):
         return out
 
 
-class Mixed_4a(nn.Module):
+class Mixed4a(nn.Module):
     def __init__(self):
-        super(Mixed_4a, self).__init__()
+        super(Mixed4a, self).__init__()
 
         self.branch0 = nn.Sequential(
             BasicConv2d(160, 64, kernel_size=1, stride=1),
@@ -73,9 +76,9 @@ def forward(self, x):
         return out
 
 
-class Mixed_5a(nn.Module):
+class Mixed5a(nn.Module):
     def __init__(self):
-        super(Mixed_5a, self).__init__()
+        super(Mixed5a, self).__init__()
         self.conv = BasicConv2d(192, 192, kernel_size=3, stride=2)
         self.maxpool = nn.MaxPool2d(3, stride=2)
 
@@ -86,9 +89,9 @@ def forward(self, x):
         return out
 
 
-class Inception_A(nn.Module):
+class InceptionA(nn.Module):
     def __init__(self):
-        super(Inception_A, self).__init__()
+        super(InceptionA, self).__init__()
         self.branch0 = BasicConv2d(384, 96, kernel_size=1, stride=1)
 
         self.branch1 = nn.Sequential(
@@ -116,9 +119,9 @@ def forward(self, x):
         return out
 
 
-class Reduction_A(nn.Module):
+class ReductionA(nn.Module):
     def __init__(self):
-        super(Reduction_A, self).__init__()
+        super(ReductionA, self).__init__()
         self.branch0 = BasicConv2d(384, 384, kernel_size=3, stride=2)
 
         self.branch1 = nn.Sequential(
@@ -137,9 +140,9 @@ def forward(self, x):
         return out
 
 
-class Inception_B(nn.Module):
+class InceptionB(nn.Module):
     def __init__(self):
-        super(Inception_B, self).__init__()
+        super(InceptionB, self).__init__()
         self.branch0 = BasicConv2d(1024, 384, kernel_size=1, stride=1)
 
         self.branch1 = nn.Sequential(
@@ -170,9 +173,9 @@ def forward(self, x):
         return out
 
 
-class Reduction_B(nn.Module):
+class ReductionB(nn.Module):
     def __init__(self):
-        super(Reduction_B, self).__init__()
+        super(ReductionB, self).__init__()
 
         self.branch0 = nn.Sequential(
             BasicConv2d(1024, 192, kernel_size=1, stride=1),
@@ -196,9 +199,9 @@ def forward(self, x):
         return out
 
 
-class Inception_C(nn.Module):
+class InceptionC(nn.Module):
     def __init__(self):
-        super(Inception_C, self).__init__()
+        super(InceptionC, self).__init__()
 
         self.branch0 = BasicConv2d(1536, 256, kernel_size=1, stride=1)
 
@@ -251,53 +254,98 @@ def forward(self, x):
 
 
 class InceptionV4(nn.Module):
-    def __init__(self, num_classes=1001):
+    def __init__(
+        self,
+        num_classes=1000,
+        in_chans=3,
+        output_stride=32,
+        drop_rate=0.0,
+        global_pool="avg",
+    ):
         super(InceptionV4, self).__init__()
-        # Special attributs
-        self.input_space = None
-        self.input_size = (299, 299, 3)
-        self.mean = None
-        self.std = None
-        # Modules
+        assert output_stride == 32
+        self.drop_rate = drop_rate
+        self.num_classes = num_classes
+        self.num_features = 1536
+
         self.features = nn.Sequential(
-            BasicConv2d(3, 32, kernel_size=3, stride=2),
+            BasicConv2d(in_chans, 32, kernel_size=3, stride=2),
             BasicConv2d(32, 32, kernel_size=3, stride=1),
             BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1),
-            Mixed_3a(),
-            Mixed_4a(),
-            Mixed_5a(),
-            Inception_A(),
-            Inception_A(),
-            Inception_A(),
-            Inception_A(),
-            Reduction_A(),  # Mixed_6a
-            Inception_B(),
-            Inception_B(),
-            Inception_B(),
-            Inception_B(),
-            Inception_B(),
-            Inception_B(),
-            Inception_B(),
-            Reduction_B(),  # Mixed_7a
-            Inception_C(),
-            Inception_C(),
-            Inception_C(),
+            Mixed3a(),
+            Mixed4a(),
+            Mixed5a(),
+            InceptionA(),
+            InceptionA(),
+            InceptionA(),
+            InceptionA(),
+            ReductionA(),  # Mixed6a
+            InceptionB(),
+            InceptionB(),
+            InceptionB(),
+            InceptionB(),
+            InceptionB(),
+            InceptionB(),
+            InceptionB(),
+            ReductionB(),  # Mixed7a
+            InceptionC(),
+            InceptionC(),
+            InceptionC(),
         )
-        self.last_linear = nn.Linear(1536, num_classes)
-
-    def logits(self, features):
-        # Allows image of any size to be processed
-        adaptiveAvgPoolWidth = features.shape[2]
-        x = F.avg_pool2d(features, kernel_size=adaptiveAvgPoolWidth)
-        x = x.view(x.size(0), -1)
-        x = self.last_linear(x)
-        return x
+        self.feature_info = [
+            dict(num_chs=64, reduction=2, module="features.2"),
+            dict(num_chs=160, reduction=4, module="features.3"),
+            dict(num_chs=384, reduction=8, module="features.9"),
+            dict(num_chs=1024, reduction=16, module="features.17"),
+            dict(num_chs=1536, reduction=32, module="features.21"),
+        ]
+        self.global_pool, self.last_linear = create_classifier(
+            self.num_features, self.num_classes, pool_type=global_pool
+        )
+
+    @torch.jit.ignore
+    def group_matcher(self, coarse=False):
+        return dict(stem=r"^features\.[012]\.", blocks=r"^features\.(\d+)")
+
+    @torch.jit.ignore
+    def set_grad_checkpointing(self, enable=True):
+        assert not enable, "gradient checkpointing not supported"
+
+    @torch.jit.ignore
+    def get_classifier(self):
+        return self.last_linear
 
-    def forward(self, input):
-        x = self.features(input)
-        x = self.logits(x)
+    def reset_classifier(self, num_classes, global_pool="avg"):
+        self.num_classes = num_classes
+        self.global_pool, self.last_linear = create_classifier(
+            self.num_features, self.num_classes, pool_type=global_pool
+        )
+
+    def forward_features(self, x):
+        return self.features(x)
+
+    def forward_head(self, x, pre_logits: bool = False):
+        x = self.global_pool(x)
+        if self.drop_rate > 0:
+            x = F.dropout(x, p=self.drop_rate, training=self.training)
+        return x if pre_logits else self.last_linear(x)
+
+    def forward(self, x):
+        x = self.forward_features(x)
+        x = self.forward_head(x)
         return x
 
 
-def inceptionv4(num_classes=1000):
-    return InceptionV4(num_classes=num_classes)
+def _create_inception_v4(variant, pretrained=False, **kwargs):
+    return build_model_with_cfg(
+        InceptionV4,
+        variant,
+        pretrained,
+        feature_cfg=dict(flatten_sequential=True),
+        **kwargs
+    )
+
+
+@register_model
+def inception_v4nobn(pretrained=False, **kwargs):
+    return _create_inception_v4("inception_v4nobn", pretrained, **kwargs)
diff --git a/wsinfer/_modellib/models.py b/wsinfer/_modellib/models.py
index da44d25..f2692bc 100644
--- a/wsinfer/_modellib/models.py
+++ b/wsinfer/_modellib/models.py
@@ -15,8 +15,8 @@
 from torch.hub import load_state_dict_from_url
 import yaml
 
-from .inceptionv4 import inceptionv4 as _inceptionv4
-from .inceptionv4_no_batchnorm import inceptionv4 as _inceptionv4_no_bn
+# Imported for side effects of registering model.
+from . import inceptionv4_no_batchnorm as _  # noqa
 from .resnet_preact import resnet34_preact as _resnet34_preact
 from .vgg16mod import vgg16mod as _vgg16mod
 from .transforms import PatchClassification
@@ -233,8 +233,6 @@ def get_sha256_of_weights(self) -> str:
 
 # Container for all models we can use that are not in timm.
 _model_registry: Dict[str, Callable[[int], torch.nn.Module]] = {
-    "inceptionv4": _inceptionv4,
-    "inceptionv4nobn": _inceptionv4_no_bn,
     "preactresnet34": _resnet34_preact,
     "vgg16mod": _vgg16mod,
 }
diff --git a/wsinfer/modeldefs/inceptionv4_tcga-brca-v1.yaml b/wsinfer/modeldefs/inceptionv4_tcga-brca-v1.yaml
index 2aad5f2..a24d779 100644
--- a/wsinfer/modeldefs/inceptionv4_tcga-brca-v1.yaml
+++ b/wsinfer/modeldefs/inceptionv4_tcga-brca-v1.yaml
@@ -3,7 +3,7 @@ version: "1.0"
 
 # The models are referenced by the pair of [architecture, weights], so this pair must
 # be unique.
-architecture: inceptionv4 # Must be a string.
+architecture: inception_v4 # Must be a string.
 name: TCGA-BRCA-v1 # Must be a string.
 # Where to get the model weights. Either a URL or path to a file.
 # If using a URL, set the url_file_name (the name of the file when it is downloaded).
diff --git a/wsinfer/modeldefs/inceptionv4nobn_tcga-tils-v1.yaml b/wsinfer/modeldefs/inceptionv4nobn_tcga-tils-v1.yaml
index 2991dfa..f542c29 100644
--- a/wsinfer/modeldefs/inceptionv4nobn_tcga-tils-v1.yaml
+++ b/wsinfer/modeldefs/inceptionv4nobn_tcga-tils-v1.yaml
@@ -4,7 +4,7 @@ version: "1.0"
 # The models are referenced by the pair of [architecture, weights], so this pair must
 # be unique.
 # Inceptionv4 without batch normalization.
-architecture: inceptionv4nobn # Must be a string.
+architecture: inception_v4nobn # Must be a string.
 name: TCGA-TILs-v1 # Must be a string.
 # Where to get the model weights. Either a URL or path to a file.
 # If using a URL, set the url_file_name (the name of the file when it is downloaded).
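
Patch note: why bias=True replaces batchnorm in BasicConv2d. A frozen
Conv2d + BatchNorm2d pair is equivalent to a single conv with rescaled
weights and a nonzero bias, so a batchnorm-free conv needs its own bias
term to represent the same function. A minimal illustrative sketch (not
code from this patch):

import torch
import torch.nn as nn

conv = nn.Conv2d(3, 8, kernel_size=3, bias=False)
bn = nn.BatchNorm2d(8, eps=0.001).eval()  # frozen running statistics
# Give the batchnorm nontrivial parameters and statistics.
nn.init.uniform_(bn.weight)
nn.init.uniform_(bn.bias)
bn.running_mean.uniform_(-1.0, 1.0)
bn.running_var.uniform_(0.5, 1.5)

# Fold conv + bn into one biased conv: y = scale * conv(x) + shift.
folded = nn.Conv2d(3, 8, kernel_size=3, bias=True)
scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
with torch.no_grad():
    folded.weight.copy_(conv.weight * scale.reshape(-1, 1, 1, 1))
    folded.bias.copy_(bn.bias - bn.running_mean * scale)

x = torch.randn(1, 3, 32, 32)
with torch.no_grad():
    assert torch.allclose(bn(conv(x)), folded(x), atol=1e-5)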
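
Patch note: the methods added to InceptionV4 (forward_features, forward_head,
reset_classifier) follow timm's standard model interface. A small usage
sketch, assuming timm is installed and weights are loaded separately; the
output shapes assume a 299x299 input:

import torch
from wsinfer._modellib.inceptionv4_no_batchnorm import inception_v4nobn

model = inception_v4nobn(num_classes=2).eval()
x = torch.randn(1, 3, 299, 299)

with torch.no_grad():
    logits = model(x)  # shape (1, 2)
    feats = model.forward_features(x)  # shape (1, 1536, 8, 8)
    pooled = model.forward_head(feats, pre_logits=True)  # shape (1, 1536)

# Replace the classification head without touching the trunk.
model.reset_classifier(num_classes=5)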