From 09bafaea5b8cd974140da0ca4cf4453127d90019 Mon Sep 17 00:00:00 2001 From: Joseph Paul Cohen Date: Thu, 2 Jan 2025 17:58:42 -0800 Subject: [PATCH] Change automatic upsample to interpolate to match skimage preprocessing (#161) * change to interpolate and refactor --- tests/test_models.py | 8 +-- .../baseline_models/chestx_det/__init__.py | 10 ++- .../baseline_models/chexpert/__init__.py | 6 +- .../baseline_models/emory_hiti/__init__.py | 12 ++-- .../baseline_models/jfhealthcare/__init__.py | 6 +- .../baseline_models/riken/__init__.py | 11 ++-- .../baseline_models/xinario/__init__.py | 11 ++-- torchxrayvision/models.py | 61 +++---------------- torchxrayvision/utils.py | 40 ++++++++++++ 9 files changed, 77 insertions(+), 88 deletions(-) diff --git a/tests/test_models.py b/tests/test_models.py index fbb9129..d951f44 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -90,14 +90,14 @@ def test_normalization_check(): # so here the first 2 pixels are set to the limits test_x[0][0][0] = ra[0] test_x[0][0][1] = ra[1] - xrv.models.warning_log = {} + xrv.utils.warning_log = {} model(test_x) - assert xrv.models.warning_log['norm_correct'] == False, ra + assert xrv.utils.warning_log['norm_correct'] == False, ra for ra in correct_ranges: test_x = torch.zeros([1,1,224,224]) test_x.uniform_(ra[0], ra[1]) - xrv.models.warning_log = {} + xrv.utils.warning_log = {} model(test_x) - assert xrv.models.warning_log['norm_correct'] == True, ra + assert xrv.utils.warning_log['norm_correct'] == True, ra diff --git a/torchxrayvision/baseline_models/chestx_det/__init__.py b/torchxrayvision/baseline_models/chestx_det/__init__.py index cee1f21..2d1f790 100644 --- a/torchxrayvision/baseline_models/chestx_det/__init__.py +++ b/torchxrayvision/baseline_models/chestx_det/__init__.py @@ -105,15 +105,13 @@ def __init__(self, cache_dir:str = None): model.eval() self.model = model - self.upsample = nn.Upsample( - size=(512, 512), - mode='bilinear', - align_corners=False, - ) def forward(self, x): + x = x.repeat(1, 3, 1, 1) - x = self.upsample(x) + + x = utils.fix_resolution(x, 512, self) + utils.warn_normalization(x) # expecting values between [-1024,1024] x = (x + 1024) / 2048 diff --git a/torchxrayvision/baseline_models/chexpert/__init__.py b/torchxrayvision/baseline_models/chexpert/__init__.py index 678d8c0..dcf03f0 100644 --- a/torchxrayvision/baseline_models/chexpert/__init__.py +++ b/torchxrayvision/baseline_models/chexpert/__init__.py @@ -6,6 +6,7 @@ import torch import torch.nn as nn from .model import Tasks2Models +from ... import utils class DenseNet(nn.Module): @@ -52,8 +53,6 @@ def __init__(self, weights_zip="", num_models=30): dynamic=False, use_gpu=self.use_gpu) - self.upsample = nn.Upsample(size=(320, 320), mode='bilinear', align_corners=False) - self.pathologies = self.targets def forward(self, x): @@ -80,7 +79,8 @@ def forward(self, x): def features(self, x): x = x.repeat(1, 3, 1, 1) - x = self.upsample(x) + x = utils.fix_resolution(x, 320, self) + utils.warn_normalization(x) # expecting values between [-1024,1024] x = x / 512 diff --git a/torchxrayvision/baseline_models/emory_hiti/__init__.py b/torchxrayvision/baseline_models/emory_hiti/__init__.py index 3048a82..f30e7f9 100644 --- a/torchxrayvision/baseline_models/emory_hiti/__init__.py +++ b/torchxrayvision/baseline_models/emory_hiti/__init__.py @@ -7,7 +7,7 @@ import torch.nn as nn import torchvision import torchxrayvision as xrv - +from ... import utils class RaceModel(nn.Module): """This model is from the work below and is trained to predict the @@ -78,12 +78,6 @@ def __init__(self): print("Loading failure. Check weights file:", self.weights_filename_local) raise e - self.upsample = nn.Upsample( - size=(320, 320), - mode='bilinear', - align_corners=False, - ) - self.targets = ["Asian", "Black", "White"] self.mean = np.array([0.485, 0.456, 0.406]) @@ -93,7 +87,9 @@ def __init__(self): def forward(self, x): x = x.repeat(1, 3, 1, 1) - x = self.upsample(x) + + x = utils.fix_resolution(x, 320, self) + utils.warn_normalization(x) # Expecting values between [-1024,1024] x = (x + 1024) / 2048 diff --git a/torchxrayvision/baseline_models/jfhealthcare/__init__.py b/torchxrayvision/baseline_models/jfhealthcare/__init__.py index 8586525..6456c54 100644 --- a/torchxrayvision/baseline_models/jfhealthcare/__init__.py +++ b/torchxrayvision/baseline_models/jfhealthcare/__init__.py @@ -9,6 +9,7 @@ import pathlib import torch import torch.nn as nn +from ... import utils class DenseNet(nn.Module): @@ -76,13 +77,14 @@ def __init__(self, **entries): raise (e) self.model = model - self.upsample = nn.Upsample(size=(512, 512), mode='bilinear', align_corners=False) self.pathologies = self.targets def forward(self, x): x = x.repeat(1, 3, 1, 1) - x = self.upsample(x) + + x = utils.fix_resolution(x, 512, self) + utils.warn_normalization(x) # expecting values between [-1024,1024] x = x / 512 diff --git a/torchxrayvision/baseline_models/riken/__init__.py b/torchxrayvision/baseline_models/riken/__init__.py index 409930f..c567a62 100644 --- a/torchxrayvision/baseline_models/riken/__init__.py +++ b/torchxrayvision/baseline_models/riken/__init__.py @@ -6,6 +6,7 @@ import torchvision import pathlib import torchxrayvision as xrv +from ... import utils class AgeModel(nn.Module): @@ -71,12 +72,6 @@ def __init__(self): print("Loading failure. Check weights file:", self.weights_filename_local) raise e - self.upsample = nn.Upsample( - size=(320, 320), - mode='bilinear', - align_corners=False, - ) - self.norm = torchvision.transforms.Normalize( [0.485, 0.456, 0.406], [0.229, 0.224, 0.225], @@ -84,7 +79,9 @@ def __init__(self): def forward(self, x): x = x.repeat(1, 3, 1, 1) - x = self.upsample(x) + + x = utils.fix_resolution(x, 320, self) + utils.warn_normalization(x) # expecting values between [-1024,1024] x = (x + 1024) / 2048 diff --git a/torchxrayvision/baseline_models/xinario/__init__.py b/torchxrayvision/baseline_models/xinario/__init__.py index e8e90e0..3922df6 100644 --- a/torchxrayvision/baseline_models/xinario/__init__.py +++ b/torchxrayvision/baseline_models/xinario/__init__.py @@ -6,6 +6,7 @@ import torchvision import pathlib import torchxrayvision as xrv +from ... import utils class ViewModel(nn.Module): @@ -63,12 +64,6 @@ def __init__(self): print("Loading failure. Check weights file:", self.weights_filename_local) raise e - self.upsample = nn.Upsample( - size=(224, 224), - mode='bilinear', - align_corners=False, - ) - self.norm = torchvision.transforms.Normalize( [0.485, 0.456, 0.406], [0.229, 0.224, 0.225], @@ -76,7 +71,9 @@ def __init__(self): def forward(self, x): x = x.repeat(1, 3, 1, 1) - x = self.upsample(x) + + x = utils.fix_resolution(x, 224, self) + utils.warn_normalization(x) # expecting values between [-1024,1024] x = (x + 1024) / 2048 diff --git a/torchxrayvision/models.py b/torchxrayvision/models.py index 6f26750..4e75197 100644 --- a/torchxrayvision/models.py +++ b/torchxrayvision/models.py @@ -297,7 +297,7 @@ def __init__(self, self.weights_filename_local = get_weights(weights, cache_dir) try: - savedmodel = torch.load(self.weights_filename_local, map_location='cpu') + savedmodel = torch.load(self.weights_filename_local, map_location='cpu', weights_only=False) # patch to load old models https://github.com/pytorch/pytorch/issues/42242 for mod in savedmodel.modules(): if not hasattr(mod, "_non_persistent_buffers_set"): @@ -313,8 +313,6 @@ def __init__(self, if "op_threshs" in model_urls[weights]: self.op_threshs = torch.tensor(model_urls[weights]["op_threshs"]) - self.upsample = nn.Upsample(size=(224, 224), mode='bilinear', align_corners=False) - def __repr__(self): if self.weights is not None: return "XRV-DenseNet121-{}".format(self.weights) @@ -322,8 +320,8 @@ def __repr__(self): return "XRV-DenseNet" def features2(self, x): - x = fix_resolution(x, 224, self) - warn_normalization(x) + x = utils.fix_resolution(x, 224, self) + utils.warn_normalization(x) features = self.features(x) out = F.relu(features, inplace=True) @@ -331,7 +329,8 @@ def features2(self, x): return out def forward(self, x): - x = fix_resolution(x, 224, self) + x = utils.fix_resolution(x, 224, self) + utils.warn_normalization(x) features = self.features2(x) out = self.classifier(features) @@ -412,7 +411,7 @@ def __init__(self, weights: str = None, apply_sigmoid: bool = False, cache_dir: self.model.conv1 = torch.nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False) try: - self.model.load_state_dict(torch.load(self.weights_filename_local)) + self.model.load_state_dict(torch.load(self.weights_filename_local, map_location='cpu', weights_only=False)) except Exception as e: print("Loading failure. Check weights file:", self.weights_filename_local) raise e @@ -420,8 +419,6 @@ def __init__(self, weights: str = None, apply_sigmoid: bool = False, cache_dir: if "op_threshs" in model_urls[weights]: self.register_buffer('op_threshs', torch.tensor(model_urls[weights]["op_threshs"])) - self.upsample = nn.Upsample(size=(512, 512), mode='bilinear', align_corners=False) - self.eval() def __repr__(self): @@ -431,8 +428,8 @@ def __repr__(self): return "XRV-ResNet" def features(self, x): - x = fix_resolution(x, 512, self) - warn_normalization(x) + x = utils.fix_resolution(x, 512, self) + utils.warn_normalization(x) x = self.model.conv1(x) x = self.model.bn1(x) @@ -449,8 +446,8 @@ def features(self, x): return x def forward(self, x): - x = fix_resolution(x, 512, self) - warn_normalization(x) + x = utils.fix_resolution(x, 512, self) + utils.warn_normalization(x) out = self.model(x) @@ -463,44 +460,6 @@ def forward(self, x): return out -warning_log = {} - - -def fix_resolution(x, resolution: int, model: nn.Module): - """Check resolution of input and resize to match requested.""" - - # just skip it if upsample was removed somehow - if not hasattr(model, 'upsample') or (model.upsample == None): - return x - - if (x.shape[2] != resolution) | (x.shape[3] != resolution): - if not hash(model) in warning_log: - print("Warning: Input size ({}x{}) is not the native resolution ({}x{}) for this model. A resize will be performed but this could impact performance.".format(x.shape[2], x.shape[3], resolution, resolution)) - warning_log[hash(model)] = True - return model.upsample(x) - return x - - -def warn_normalization(x): - """Check normalization of input and warn if possibly wrong. When - processing an image that may likely not have the correct - normalization we can issue a warning. But running min and max on - every image/batch is costly so we only do it on the first image/batch. - """ - - # Only run this check on the first image so we don't hurt performance. - if not "norm_check" in warning_log: - x_min = x.min() - x_max = x.max() - if torch.logical_or(-255 < x_min, x_max < 255) or torch.logical_or(x_min < -1024, 1024 < x_max): - print(f'Warning: Input image does not appear to be normalized correctly. The input image has the range [{x_min:.2f},{x_max:.2f}] which doesn\'t seem to be in the [-1024,1024] range. This warning may be wrong though. Only the first image is tested and we are only using a heuristic in an attempt to save a user from using the wrong normalization.') - warning_log["norm_correct"] = False - else: - warning_log["norm_correct"] = True - - warning_log["norm_check"] = True - - def op_norm(outputs, op_threshs): """Normalize outputs according to operating points for a given model. Args: diff --git a/torchxrayvision/utils.py b/torchxrayvision/utils.py index 8e6fa0f..0ea404e 100644 --- a/torchxrayvision/utils.py +++ b/torchxrayvision/utils.py @@ -144,3 +144,43 @@ def infer(model: torch.nn.Module, dataset: torch.utils.data.Dataset, threads=4, preds.append(output) return np.concatenate(preds) + + +warning_log = {} + +def fix_resolution(x, resolution: int, model): + """Check resolution of input and resize to match requested.""" + + if len(x.shape) == 3: + # Extend to be 4D + x = x[None,...] + + if x.shape[2] != x.shape[3]: + raise Exception(f"Height and width of the image must be the same. Input: {x.shape[2]} != {x.shape[3]}. Perform a center crop first.") + + if (x.shape[2] != resolution) | (x.shape[3] != resolution): + if not hash(model) in warning_log: + print("Warning: Input size ({}x{}) is not the native resolution ({}x{}) for this model. A resize will be performed but this could impact performance.".format(x.shape[2], x.shape[3], resolution, resolution)) + warning_log[hash(model)] = True + return torch.nn.functional.interpolate(x, size=(resolution, resolution), mode='bilinear', antialias=True) + return x + + +def warn_normalization(x): + """Check normalization of input and warn if possibly wrong. When + processing an image that may likely not have the correct + normalization we can issue a warning. But running min and max on + every image/batch is costly so we only do it on the first image/batch. + """ + + # Only run this check on the first image so we don't hurt performance. + if not "norm_check" in warning_log: + x_min = x.min() + x_max = x.max() + if torch.logical_or(-255 < x_min, x_max < 255) or torch.logical_or(x_min < -1025, 1025 < x_max): + print(f'Warning: Input image does not appear to be normalized correctly. The input image has the range [{x_min:.2f},{x_max:.2f}] which doesn\'t seem to be in the [-1024,1024] range. This warning may be wrong though. Only the first image is tested and we are only using a heuristic in an attempt to save a user from using the wrong normalization.') + warning_log["norm_correct"] = False + else: + warning_log["norm_correct"] = True + + warning_log["norm_check"] = True