Skip to content

Commit

Permalink
Add activation parameter to ResNet (#7749)
Browse files Browse the repository at this point in the history
Fixes #7653.

### Description
Adds an `act` parameter to `ResNet` and its building blocks so that the
activation type and its arguments (for example `inplace`) can be configured.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [x] New tests added to cover the changes.
- [x] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [x] In-line docstrings updated.
- [x] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: Peter Kaplinsky <[email protected]>
Co-authored-by: Peter Kaplinsky <[email protected]>
  • Loading branch information
Pkaps25 and Peter Kaplinsky authored May 9, 2024
1 parent d83fa56 commit 258f56d
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 12 deletions.
28 changes: 17 additions & 11 deletions monai/networks/nets/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from monai.networks.blocks.encoder import BaseEncoder
from monai.networks.layers.factories import Conv, Norm, Pool
from monai.networks.layers.utils import get_pool_layer
from monai.networks.layers.utils import get_act_layer, get_pool_layer
from monai.utils import ensure_tuple_rep
from monai.utils.module import look_up_option, optional_import

Expand Down Expand Up @@ -78,6 +78,7 @@ def __init__(
spatial_dims: int = 3,
stride: int = 1,
downsample: nn.Module | partial | None = None,
act: str | tuple = ("relu", {"inplace": True}),
) -> None:
"""
Args:
Expand All @@ -86,6 +87,7 @@ def __init__(
spatial_dims: number of spatial dimensions of the input image.
stride: stride to use for first conv layer.
downsample: which downsample layer to use.
act: activation type and arguments. Defaults to relu.
"""
super().__init__()

Expand All @@ -94,7 +96,7 @@ def __init__(

self.conv1 = conv_type(in_planes, planes, kernel_size=3, padding=1, stride=stride, bias=False)
self.bn1 = norm_type(planes)
self.relu = nn.ReLU(inplace=True)
self.act = get_act_layer(name=act)
self.conv2 = conv_type(planes, planes, kernel_size=3, padding=1, bias=False)
self.bn2 = norm_type(planes)
self.downsample = downsample
Expand All @@ -105,7 +107,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

out: torch.Tensor = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.act(out)

out = self.conv2(out)
out = self.bn2(out)
Expand All @@ -114,7 +116,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
residual = self.downsample(x)

out += residual
out = self.relu(out)
out = self.act(out)

return out

Expand All @@ -129,6 +131,7 @@ def __init__(
spatial_dims: int = 3,
stride: int = 1,
downsample: nn.Module | partial | None = None,
act: str | tuple = ("relu", {"inplace": True}),
) -> None:
"""
Args:
Expand All @@ -137,6 +140,7 @@ def __init__(
spatial_dims: number of spatial dimensions of the input image.
stride: stride to use for second conv layer.
downsample: which downsample layer to use.
act: activation type and arguments. Defaults to relu.
"""

super().__init__()
Expand All @@ -150,7 +154,7 @@ def __init__(
self.bn2 = norm_type(planes)
self.conv3 = conv_type(planes, planes * self.expansion, kernel_size=1, bias=False)
self.bn3 = norm_type(planes * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.act = get_act_layer(name=act)
self.downsample = downsample
self.stride = stride

Expand All @@ -159,11 +163,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

out: torch.Tensor = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.act(out)

out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.act(out)

out = self.conv3(out)
out = self.bn3(out)
Expand All @@ -172,7 +176,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
residual = self.downsample(x)

out += residual
out = self.relu(out)
out = self.act(out)

return out

Expand Down Expand Up @@ -202,6 +206,7 @@ class ResNet(nn.Module):
num_classes: number of output (classifications).
feed_forward: whether to add the FC layer for the output, default to `True`.
bias_downsample: whether to use bias term in the downsampling block when `shortcut_type` is 'B', default to `True`.
act: activation type and arguments. Defaults to relu.
"""

Expand All @@ -220,6 +225,7 @@ def __init__(
num_classes: int = 400,
feed_forward: bool = True,
bias_downsample: bool = True, # for backwards compatibility (also see PR #5477)
act: str | tuple = ("relu", {"inplace": True}),
) -> None:
super().__init__()

Expand Down Expand Up @@ -257,7 +263,7 @@ def __init__(
bias=False,
)
self.bn1 = norm_type(self.in_planes)
self.relu = nn.ReLU(inplace=True)
self.act = get_act_layer(name=act)
self.maxpool = pool_type(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, block_inplanes[0], layers[0], spatial_dims, shortcut_type)
self.layer2 = self._make_layer(block, block_inplanes[1], layers[1], spatial_dims, shortcut_type, stride=2)
Expand Down Expand Up @@ -329,7 +335,7 @@ def _make_layer(
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.act(x)
if not self.no_max_pool:
x = self.maxpool(x)

Expand Down Expand Up @@ -396,7 +402,7 @@ def forward(self, inputs: torch.Tensor):
"""
x = self.conv1(inputs)
x = self.bn1(x)
x = self.relu(x)
x = self.act(x)

features = []
features.append(x)
Expand Down
19 changes: 18 additions & 1 deletion tests/test_resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@
"num_classes": 3,
"conv1_t_size": [3],
"conv1_t_stride": 1,
"act": ("relu", {"inplace": False}),
},
(1, 2, 32),
(1, 3),
Expand Down Expand Up @@ -185,13 +186,29 @@
(1, 3),
]

TEST_CASE_8 = [
{
"block": "bottleneck",
"layers": [3, 4, 6, 3],
"block_inplanes": [64, 128, 256, 512],
"spatial_dims": 1,
"n_input_channels": 2,
"num_classes": 3,
"conv1_t_size": [3],
"conv1_t_stride": 1,
"act": ("relu", {"inplace": False}),
},
(1, 2, 32),
(1, 3),
]

TEST_CASES = []
PRETRAINED_TEST_CASES = []
for case in [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_2_A, TEST_CASE_3_A]:
for model in [resnet10, resnet18, resnet34, resnet50, resnet101, resnet152, resnet200]:
TEST_CASES.append([model, *case])
PRETRAINED_TEST_CASES.append([model, *case])
for case in [TEST_CASE_5, TEST_CASE_5_A, TEST_CASE_6, TEST_CASE_7]:
for case in [TEST_CASE_5, TEST_CASE_5_A, TEST_CASE_6, TEST_CASE_7, TEST_CASE_8]:
TEST_CASES.append([ResNet, *case])

TEST_SCRIPT_CASES = [
Expand Down

0 comments on commit 258f56d

Please sign in to comment.