Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable swinunetr-v2 #6203

Merged
merged 12 commits into from
Mar 22, 2023
38 changes: 38 additions & 0 deletions monai/networks/nets/swin_unetr.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def __init__(
use_checkpoint: bool = False,
spatial_dims: int = 3,
downsample="merging",
use_v2=False,
) -> None:
"""
Args:
Expand All @@ -84,6 +85,7 @@ def __init__(
downsample: module used for downsampling, available options are `"mergingv2"`, `"merging"` and a
user-specified `nn.Module` following the API defined in :py:class:`monai.networks.nets.PatchMerging`.
The default is currently `"merging"` (the original version defined in v0.9.0).
use_v2: using swinunetr_v2, which adds a residual convolution block at the beggining of each swin stage.

Examples::

Expand Down Expand Up @@ -142,6 +144,7 @@ def __init__(
use_checkpoint=use_checkpoint,
spatial_dims=spatial_dims,
downsample=look_up_option(downsample, MERGING_MODE) if isinstance(downsample, str) else downsample,
use_v2=use_v2,
)

self.encoder1 = UnetrBasicBlock(
Expand Down Expand Up @@ -921,6 +924,7 @@ def __init__(
use_checkpoint: bool = False,
spatial_dims: int = 3,
downsample="merging",
use_v2=False,
) -> None:
"""
Args:
Expand All @@ -942,6 +946,7 @@ def __init__(
downsample: module used for downsampling, available options are `"mergingv2"`, `"merging"` and a
user-specified `nn.Module` following the API defined in :py:class:`monai.networks.nets.PatchMerging`.
The default is currently `"merging"` (the original version defined in v0.9.0).
use_v2: using swinunetr_v2, which adds a residual convolution block at the beggining of each swin stage.
"""

super().__init__()
Expand All @@ -959,10 +964,16 @@ def __init__(
)
self.pos_drop = nn.Dropout(p=drop_rate)
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
self.use_v2 = use_v2
self.layers1 = nn.ModuleList()
self.layers2 = nn.ModuleList()
self.layers3 = nn.ModuleList()
self.layers4 = nn.ModuleList()
if self.use_v2:
self.layers1c = nn.ModuleList()
self.layers2c = nn.ModuleList()
self.layers3c = nn.ModuleList()
self.layers4c = nn.ModuleList()
down_sample_mod = look_up_option(downsample, MERGING_MODE) if isinstance(downsample, str) else downsample
for i_layer in range(self.num_layers):
layer = BasicLayer(
Expand All @@ -987,6 +998,25 @@ def __init__(
self.layers3.append(layer)
elif i_layer == 3:
self.layers4.append(layer)
if self.use_v2:
layerc = UnetrBasicBlock(
spatial_dims=3,
in_channels=embed_dim * 2**i_layer,
out_channels=embed_dim * 2**i_layer,
kernel_size=3,
stride=1,
norm_name="instance",
res_block=True,
)
if i_layer == 0:
self.layers1c.append(layerc)
elif i_layer == 1:
self.layers2c.append(layerc)
elif i_layer == 2:
self.layers3c.append(layerc)
elif i_layer == 3:
self.layers4c.append(layerc)

self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))

def proj_out(self, x, normalize=False):
Expand All @@ -1008,12 +1038,20 @@ def forward(self, x, normalize=True):
x0 = self.patch_embed(x)
x0 = self.pos_drop(x0)
x0_out = self.proj_out(x0, normalize)
if self.use_v2:
x0 = self.layers1c[0](x0.contiguous())
x1 = self.layers1[0](x0.contiguous())
x1_out = self.proj_out(x1, normalize)
if self.use_v2:
x1 = self.layers2c[0](x1.contiguous())
x2 = self.layers2[0](x1.contiguous())
x2_out = self.proj_out(x2, normalize)
if self.use_v2:
x2 = self.layers3c[0](x2.contiguous())
x3 = self.layers3[0](x2.contiguous())
x3_out = self.proj_out(x3, normalize)
if self.use_v2:
x3 = self.layers4c[0](x3.contiguous())
x4 = self.layers4[0](x3.contiguous())
x4_out = self.proj_out(x4, normalize)
return [x0_out, x1_out, x2_out, x3_out, x4_out]