Skip to content

Commit

Permalink
Enble swinunetr-v2 (Project-MONAI#6203)
Browse files Browse the repository at this point in the history
Fixes Project-MONAI#6183 .

### Description
Added a "use_v2" option in swinunetr initialization. Default is false
will not affect the original swinunetr.
Once changed to true, will become swinunetr-v2 with 4 additional
convolution block.
Tested running from auto3dseg bundles, no change needed for original
swinunetr, and works for swinunetr-v2
Tested running from monai research contribution repo for swinuntr, no
change needed for original swinunetr, and works for swinunetr-v2
Tested TensorRT, compiled .ts file successfully.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: heyufan1995 <[email protected]>
  • Loading branch information
heyufan1995 authored and jak0bw committed Mar 28, 2023
1 parent 901b7ec commit 62a73c8
Showing 1 changed file with 38 additions and 0 deletions.
38 changes: 38 additions & 0 deletions monai/networks/nets/swin_unetr.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def __init__(
use_checkpoint: bool = False,
spatial_dims: int = 3,
downsample="merging",
use_v2=False,
) -> None:
"""
Args:
Expand All @@ -84,6 +85,7 @@ def __init__(
downsample: module used for downsampling, available options are `"mergingv2"`, `"merging"` and a
user-specified `nn.Module` following the API defined in :py:class:`monai.networks.nets.PatchMerging`.
The default is currently `"merging"` (the original version defined in v0.9.0).
use_v2: using swinunetr_v2, which adds a residual convolution block at the beggining of each swin stage.
Examples::
Expand Down Expand Up @@ -142,6 +144,7 @@ def __init__(
use_checkpoint=use_checkpoint,
spatial_dims=spatial_dims,
downsample=look_up_option(downsample, MERGING_MODE) if isinstance(downsample, str) else downsample,
use_v2=use_v2,
)

self.encoder1 = UnetrBasicBlock(
Expand Down Expand Up @@ -921,6 +924,7 @@ def __init__(
use_checkpoint: bool = False,
spatial_dims: int = 3,
downsample="merging",
use_v2=False,
) -> None:
"""
Args:
Expand All @@ -942,6 +946,7 @@ def __init__(
downsample: module used for downsampling, available options are `"mergingv2"`, `"merging"` and a
user-specified `nn.Module` following the API defined in :py:class:`monai.networks.nets.PatchMerging`.
The default is currently `"merging"` (the original version defined in v0.9.0).
use_v2: using swinunetr_v2, which adds a residual convolution block at the beggining of each swin stage.
"""

super().__init__()
Expand All @@ -959,10 +964,16 @@ def __init__(
)
self.pos_drop = nn.Dropout(p=drop_rate)
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
self.use_v2 = use_v2
self.layers1 = nn.ModuleList()
self.layers2 = nn.ModuleList()
self.layers3 = nn.ModuleList()
self.layers4 = nn.ModuleList()
if self.use_v2:
self.layers1c = nn.ModuleList()
self.layers2c = nn.ModuleList()
self.layers3c = nn.ModuleList()
self.layers4c = nn.ModuleList()
down_sample_mod = look_up_option(downsample, MERGING_MODE) if isinstance(downsample, str) else downsample
for i_layer in range(self.num_layers):
layer = BasicLayer(
Expand All @@ -987,6 +998,25 @@ def __init__(
self.layers3.append(layer)
elif i_layer == 3:
self.layers4.append(layer)
if self.use_v2:
layerc = UnetrBasicBlock(
spatial_dims=3,
in_channels=embed_dim * 2**i_layer,
out_channels=embed_dim * 2**i_layer,
kernel_size=3,
stride=1,
norm_name="instance",
res_block=True,
)
if i_layer == 0:
self.layers1c.append(layerc)
elif i_layer == 1:
self.layers2c.append(layerc)
elif i_layer == 2:
self.layers3c.append(layerc)
elif i_layer == 3:
self.layers4c.append(layerc)

self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))

def proj_out(self, x, normalize=False):
Expand All @@ -1008,12 +1038,20 @@ def forward(self, x, normalize=True):
x0 = self.patch_embed(x)
x0 = self.pos_drop(x0)
x0_out = self.proj_out(x0, normalize)
if self.use_v2:
x0 = self.layers1c[0](x0.contiguous())
x1 = self.layers1[0](x0.contiguous())
x1_out = self.proj_out(x1, normalize)
if self.use_v2:
x1 = self.layers2c[0](x1.contiguous())
x2 = self.layers2[0](x1.contiguous())
x2_out = self.proj_out(x2, normalize)
if self.use_v2:
x2 = self.layers3c[0](x2.contiguous())
x3 = self.layers3[0](x2.contiguous())
x3_out = self.proj_out(x3, normalize)
if self.use_v2:
x3 = self.layers4c[0](x3.contiguous())
x4 = self.layers4[0](x3.contiguous())
x4_out = self.proj_out(x4, normalize)
return [x0_out, x1_out, x2_out, x3_out, x4_out]

0 comments on commit 62a73c8

Please sign in to comment.