
Commit

tests
ffiirree committed Dec 11, 2021
1 parent 992fc58 commit f03fbc3
Showing 8 changed files with 106 additions and 20 deletions.
13 changes: 8 additions & 5 deletions cvm/models/core/blocks.py
@@ -669,7 +669,7 @@ def __init__(
        )


-class SEBlock(nn.Module):
+class SEBlock(nn.Sequential):
    """Squeeze excite block
    """

@@ -680,8 +680,6 @@ def __init__(
        inner_activation_fn: nn.Module = None,
        gating_fn: nn.Module = None
    ):
-        super().__init__()
-
        squeezed_channels = make_divisible(int(channels * ratio), _SE_DIVISOR)
        inner_activation_fn = inner_activation_fn or _SE_INNER_NONLINEAR
        gating_fn = gating_fn or _SE_GATING_FN
@@ -696,10 +694,15 @@ def __init__(
        layers['expand'] = Conv2d1x1(squeezed_channels, channels, bias=True)
        layers['gate'] = gating_fn()

-        self.se = nn.Sequential(layers)
+        super().__init__(layers)

+    def _forward(self, input):
+        for module in self:
+            input = module(input)
+        return input
+
    def forward(self, x):
-        return x * self.se(x)
+        return x * self._forward(x)


class ChannelChunk(nn.Module):
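Deriving SEBlock from nn.Sequential registers the squeeze-and-excite layers directly on the block, which is what lets the new tests below reach them as se.act and se.gate; the overridden forward still returns the input scaled channel-wise by the gate. A minimal usage sketch, assuming only the constructor signature exercised in tests/test_blocks.py:

    import torch
    from cvm.models.core import blocks

    x = torch.randn(2, 64, 32, 32)
    se = blocks.SEBlock(64, 0.25)  # squeeze to ~16 channels (subject to make_divisible), then gate
    y = se(x)                      # forward computes x * gate-path(x)
    assert y.shape == x.shape      # channel-wise rescaling preserves the shape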
1 change: 0 additions & 1 deletion cvm/models/efficientnet.py
@@ -10,7 +10,6 @@
_BN_MOMENTUM = 0.01


-@export
def efficientnet_params(model_name):
    """Get efficientnet params based on model name."""
    params_dict = {
2 changes: 1 addition & 1 deletion cvm/models/regnet.py
@@ -247,7 +247,7 @@ def regnet_x_32gf(pretrained: bool = False, pth: str = None, progress: bool = Tr


@export
def regnet_y_200mf(pretrained: bool = False, pth: str = None, progress: bool = True, **kwargs):
-    return _regnet(13, 24, 36,44, 2.49, 1.0, 8, 0.25, pretrained, pth, progress, **kwargs)
+    return _regnet(13, 24, 36.44, 2.49, 1.0, 8, 0.25, pretrained, pth, progress, **kwargs)


@export
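For context on the fix above: the comma in 36,44 made Python pass 36 and 44 as two separate positional arguments, silently shifting every later parameter of _regnet one slot to the right; the intended value is the single float 36.44 (by position this looks like RegNet's width slope w_a, an inference from the RegNet design space rather than from this diff):

    # Before: 36 and 44 occupy two argument slots, pushing 2.49, 1.0, 8, ... one place right.
    _regnet(13, 24, 36, 44, 2.49, 1.0, 8, 0.25, pretrained, pth, progress)
    # After: one float, and every parameter lands in its intended slot.
    _regnet(13, 24, 36.44, 2.49, 1.0, 8, 0.25, pretrained, pth, progress)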
15 changes: 4 additions & 11 deletions cvm/models/rexnet.py
@@ -147,7 +147,8 @@ def __init__(
        in_channels: int = 3,
        num_classes: int = 1000,
        dropout_rate: float = 0.2,
-        thumbnail: bool = False
+        thumbnail: bool = False,
+        **kwargs: Any
    ):
        super().__init__()

@@ -188,15 +189,7 @@ def forward(self, x):
@export
def rexnet_plain(pretrained: bool = False, pth: str = None, progress: bool = True, **kwargs: Any):
    model = ReXNetPlain(**kwargs)

-    if pretrained:
-        if pth is not None:
-            state_dict = torch.load(os.path.expanduser(pth))
-        else:
-            assert 'url' in kwargs and kwargs['url'] != '', 'Invalid URL.'
-            state_dict = torch.hub.load_state_dict_from_url(
-                kwargs['url'],
-                progress=progress
-            )
-        model.load_state_dict(state_dict)
+    if pretrained:
+        load_from_local_or_url(model, pth, kwargs.get('url', None), progress)
    return model
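The ten lines of inline checkpoint logic collapse into one call to a shared helper. A hypothetical sketch of load_from_local_or_url, reconstructed from the code it replaces here (the real helper lives elsewhere in cvm and may differ):

    import os
    import torch

    def load_from_local_or_url(model, pth=None, url=None, progress=True):
        # Prefer an explicit local checkpoint path when one is given.
        if pth is not None:
            state_dict = torch.load(os.path.expanduser(pth))
        else:
            # Otherwise fall back to downloading the checkpoint from a URL.
            assert url, 'Invalid URL.'
            state_dict = torch.hub.load_state_dict_from_url(url, progress=progress)
        model.load_state_dict(state_dict)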
2 changes: 1 addition & 1 deletion cvm/version.py
@@ -1 +1 @@
-__version__ = '0.0.8'
+__version__ = '0.0.9'
2 changes: 1 addition & 1 deletion setup.py
@@ -19,5 +19,5 @@
        'tqdm',
        # 'nvidia-dali-cuda110 >= 1.7'
    ],
-    packages=find_packages()
+    packages=find_packages(exclude=['tests'])
)
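Excluding tests keeps the new top-level test package out of the installed distribution. Note that find_packages matches glob patterns, so covering nested test packages as well would look like this (a sketch, not part of this commit):

    from setuptools import find_packages

    packages = find_packages(exclude=['tests', 'tests.*'])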
62 changes: 62 additions & 0 deletions tests/test_blocks.py
@@ -0,0 +1,62 @@
+from functools import partial
+import pytest
+import torch
+import torch.nn as nn
+from cvm.models.core import blocks
+
+
+def test_se_block_forward():
+    inputs = torch.randn(16, 3, 56, 56)
+
+    se = blocks.SEBlock(3, 0.25)
+
+    outputs = se(inputs)
+    assert outputs.shape == inputs.shape
+    assert isinstance(se.act, nn.ReLU)
+    assert isinstance(se.gate, nn.Sigmoid)
+
+
+def test_se_block_decorator():
+    with blocks.se(inner_nonlinear=nn.SiLU, gating_fn=nn.Hardsigmoid):
+        se = blocks.SEBlock(3, 0.25)
+
+    assert isinstance(se.act, nn.SiLU)
+    assert isinstance(se.gate, nn.Hardsigmoid)
+
+
+def test_normalizer_decorator():
+    with blocks.normalizer(None):
+        layers = blocks.norm_activation(3)
+
+    assert len(layers) == 1
+    assert isinstance(layers[0], nn.ReLU)
+
+    with blocks.normalizer(nn.LayerNorm, position='before'):
+        layers = blocks.norm_activation(3)
+
+    assert len(layers) == 2
+    assert isinstance(layers[0], nn.LayerNorm)
+    assert isinstance(layers[1], nn.ReLU)
+
+    with blocks.normalizer(partial(nn.BatchNorm2d, eps=0.1), position='after'):
+        layers = blocks.norm_activation(3)
+
+    assert len(layers) == 2
+    assert isinstance(layers[0], nn.ReLU)
+    assert isinstance(layers[1], nn.BatchNorm2d)
+    assert layers[1].eps == 0.1
+
+
+def test_nonlinear_decorator():
+    with blocks.nonlinear(None):
+        layers = blocks.norm_activation(3)
+
+    assert len(layers) == 1
+    assert isinstance(layers[0], nn.BatchNorm2d)
+
+    with blocks.nonlinear(nn.SiLU):
+        layers = blocks.norm_activation(3)
+
+    assert len(layers) == 2
+    assert isinstance(layers[0], nn.BatchNorm2d)
+    assert isinstance(layers[1], nn.SiLU)
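The blocks.se, blocks.normalizer, and blocks.nonlinear context managers used above temporarily swap the module-level defaults that the block factories fall back on (e.g. the _SE_INNER_NONLINEAR and _SE_GATING_FN globals visible in the blocks.py diff). A hypothetical sketch of that pattern, not cvm's actual implementation:

    from contextlib import contextmanager
    import torch.nn as nn

    _DEFAULT_NONLINEAR = nn.ReLU  # module-level default a block factory falls back on

    @contextmanager
    def nonlinear(fn):
        # Swap the default in for the duration of the with-block, then restore it.
        global _DEFAULT_NONLINEAR
        saved, _DEFAULT_NONLINEAR = _DEFAULT_NONLINEAR, fn
        try:
            yield
        finally:
            _DEFAULT_NONLINEAR = saved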
29 changes: 29 additions & 0 deletions tests/test_models.py
@@ -0,0 +1,29 @@
+import pytest
+import torch
+from cvm.models.core import SegmentationModel
+from cvm.utils import list_models, create_model
+
+
+@pytest.mark.parametrize('name', list_models('cvm'))
+def test_model_forward(name):
+    model = create_model(
+        name,
+        dropout_rate=0.,
+        drop_path_rate=0.,
+        num_classes=10,
+        cuda=False
+    )
+
+    model.eval()
+
+    inputs = torch.randn((1, 3, 224, 224))
+    outputs = model(inputs)
+
+    if name in ['unet', 'vae', 'dcgan']:
+        ...
+    elif isinstance(model, SegmentationModel):
+        assert outputs[0].shape == torch.Size([1, 10, 224, 224])
+        assert not torch.isnan(outputs[0]).any(), 'Output included NaNs'
+    else:
+        assert outputs.shape == torch.Size([1, 10])
+        assert not torch.isnan(outputs).any(), 'Output included NaNs'
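The parametrized test above amounts to one smoke-test forward pass per registered model. The same registry API works outside the test suite (a minimal sketch using only the calls the test itself makes, and assuming rexnet_plain is among the registered names):

    import torch
    from cvm.utils import list_models, create_model

    print(list_models('cvm')[:5])  # a few registered model names
    model = create_model('rexnet_plain', num_classes=10, cuda=False)
    model.eval()
    logits = model(torch.randn(1, 3, 224, 224))  # shape [1, 10]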
