import torch
import torch.nn as nn
import torch.nn.functional as F
from pdb import set_trace as stx
import numbers
from functools import partial

from src.mm.blockdiag_linear import BlockdiagLinear
from einops import rearrange


## Fusion attention: MDTA-style transposed (channel) attention whose queries come from the
## deep branch and whose keys/values come from the concatenated shallow + deep branches.
class FusionMDTA(nn.Module):
    def __init__(self, dim, num_heads):
        super(FusionMDTA, self).__init__()
        self.num_heads = num_heads
        self.temperature = nn.Parameter(torch.ones(1, num_heads, 1, 1))

        # query projection (deep features only)
        self.q = nn.Conv2d(dim, dim, kernel_size=1, bias=False)
        self.q_conv = nn.Conv2d(dim, dim, kernel_size=3, padding=1, groups=dim, bias=False)

        # key/value projection over the full (shallow + deep) input
        self.fs_dim = int(dim * 0.2)
        kv_dim = self.fs_dim + dim
        self.kv = nn.Conv2d(kv_dim, kv_dim * 2, kernel_size=1, bias=False)
        self.kv_conv = nn.Conv2d(kv_dim * 2, kv_dim * 2, kernel_size=3, padding=1, groups=kv_dim * 2, bias=False)
        self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=False)

        # depth-wise separable conv on the value branch (skip path)
        self.dwconv = nn.Sequential(
            nn.Conv2d(kv_dim, kv_dim, kernel_size=1, bias=False),
            nn.Conv2d(kv_dim, kv_dim, kernel_size=3, padding=1, groups=kv_dim, bias=False),
            nn.Conv2d(kv_dim, dim, kernel_size=1, bias=False)
        )

    def forward(self, x):
        b, c, h, w = x.shape
        # deep feature: the last `dim` channels of the concatenated input
        q = x[:, self.fs_dim:c, :, :]
        q = self.q_conv(self.q(q))
        k, v = self.kv_conv(self.kv(x)).chunk(2, dim=1)
        skip_v = v

        q = q.reshape(b, self.num_heads, -1, h * w)
        k = k.reshape(b, self.num_heads, -1, h * w)
        v = v.reshape(b, self.num_heads, -1, h * w)
        q, k = F.normalize(q, dim=-1), F.normalize(k, dim=-1)

        # transposed attention: the attention map is computed across channels, not pixels
        attn = torch.softmax(torch.matmul(q, k.transpose(-2, -1).contiguous()) * self.temperature, dim=-1)
        out = self.project_out(torch.matmul(attn, v).reshape(b, -1, h, w) + self.dwconv(skip_v))
        return out
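
# Illustrative sketch (not part of the original model): `_fusion_mdta_shape_example` is a
# hypothetical helper added only to document the expected tensor shapes. FusionMDTA takes
# the concatenation of a shallow branch (int(0.2 * dim) channels) and a deep branch
# (dim channels) and returns `dim` channels.
def _fusion_mdta_shape_example():
    mdta = FusionMDTA(dim=48, num_heads=1)
    x = torch.rand(1, 48 + int(48 * 0.2), 64, 64)   # (B, 57, H, W): shallow (9) + deep (48)
    y = mdta(x)
    assert y.shape == (1, 48, 64, 64)               # output keeps only the deep width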


#############################
## Selective Kernel Feature Fusion (SKFF)
class SKFF(nn.Module):
    def __init__(self, in_channels, height=2, reduction=8, bias=False):
        super(SKFF, self).__init__()

        self.height = height
        d = max(int(in_channels / reduction), 4)

        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv_du = nn.Sequential(nn.Conv2d(in_channels, d, 1, padding=0, bias=bias), nn.ReLU())

        self.fcs = nn.ModuleList([])
        for i in range(self.height):
            self.fcs.append(nn.Conv2d(d, in_channels, kernel_size=1, stride=1, bias=bias))

        self.softmax = nn.Softmax(dim=1)

    def forward(self, inp_feats):
        batch_size = inp_feats[0].shape[0]
        n_feats = inp_feats[0].shape[1]

        inp_feats = torch.cat(inp_feats, dim=1)
        inp_feats = inp_feats.view(batch_size, self.height, n_feats, inp_feats.shape[2], inp_feats.shape[3])

        feats_U = torch.sum(inp_feats, dim=1)
        feats_S = self.avg_pool(feats_U)
        feats_Z = self.conv_du(feats_S)

        attention_vectors = [fc(feats_Z) for fc in self.fcs]
        attention_vectors = torch.cat(attention_vectors, dim=1)
        attention_vectors = attention_vectors.view(batch_size, self.height, n_feats, 1, 1)

        attention_vectors = self.softmax(attention_vectors)
        feats_V = torch.sum(inp_feats * attention_vectors, dim=1)

        return feats_V
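
# Illustrative sketch (not part of the original model): `_skff_shape_example` is a
# hypothetical helper showing that SKFF fuses `height` equal-shape feature maps into a
# single map of the same shape, as used for the level-3/level-2 skip connections below.
def _skff_shape_example():
    skff = SKFF(in_channels=192, height=2)
    enc = torch.rand(1, 192, 32, 32)       # encoder skip feature
    dec = torch.rand(1, 192, 32, 32)       # upsampled decoder feature
    fused = skff([enc, dec])
    assert fused.shape == (1, 192, 32, 32)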


#####################
def to_3d(x):
    return rearrange(x, 'b c h w -> b (h w) c')


def to_4d(x, h, w):
    return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)


class BiasFree_LayerNorm(nn.Module):
    def __init__(self, normalized_shape):
        super(BiasFree_LayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        normalized_shape = torch.Size(normalized_shape)
        assert len(normalized_shape) == 1

        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.normalized_shape = normalized_shape

    def forward(self, x):
        sigma = x.var(-1, keepdim=True, unbiased=False)
        return x / torch.sqrt(sigma + 1e-5) * self.weight


class WithBias_LayerNorm(nn.Module):
    def __init__(self, normalized_shape):
        super(WithBias_LayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        normalized_shape = torch.Size(normalized_shape)
        assert len(normalized_shape) == 1

        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.normalized_shape = normalized_shape

    def forward(self, x):
        mu = x.mean(-1, keepdim=True)
        sigma = x.var(-1, keepdim=True, unbiased=False)
        return (x - mu) / torch.sqrt(sigma + 1e-5) * self.weight + self.bias


class LayerNorm(nn.Module):
    def __init__(self, dim, LayerNorm_type):
        super(LayerNorm, self).__init__()
        if LayerNorm_type == 'BiasFree':
            self.body = BiasFree_LayerNorm(dim)
        else:
            self.body = WithBias_LayerNorm(dim)

    def forward(self, x):
        h, w = x.shape[-2:]
        return to_4d(self.body(to_3d(x)), h, w)


#########################
# M2MLP: feed-forward network built from Monarch-style block-diagonal linear layers
class FeedForward(nn.Module):
    """Applies the MLP."""

    def __init__(self, dim, ffn_expansion_factor, bias):
        super().__init__()
        #self.config = config
        hidden_size = int(dim * ffn_expansion_factor)
        #if self.config.use_monarch_mlp:
        linear_cls = partial(BlockdiagLinear, nblocks=4)
        #else:
        #linear_cls = nn.Linear
        self.linear = linear_cls(dim, hidden_size, bias=bias)
        self.act = nn.GELU(approximate='none')
        self.wo = linear_cls(hidden_size, dim, bias=bias)
        #self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        #self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, x):
        """Compute new hidden states from the current hidden states `x`."""
        #residual_connection = x
        x = self.linear(x)
        x = self.act(x)
        #x = self.dropout(x)
        x = self.wo(x)
        #x = self.layernorm(x + residual_connection)
        return x


## Top-K Sparse Attention (TKSA)
class Attention(nn.Module):
    def __init__(self, dim, num_heads, bias):
        super(Attention, self).__init__()
        self.num_heads = num_heads
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))

        self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
        self.qkv_dwconv = nn.Conv2d(dim * 3, dim * 3, kernel_size=3, stride=1, padding=1, groups=dim * 3, bias=bias)
        self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)

        # depth-wise separable conv on the value branch (skip path)
        self.dwconv = nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size=1, bias=bias),
            nn.Conv2d(dim, dim, kernel_size=3, padding=1, groups=dim, bias=bias),
            nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
        )
        self.attn_drop = nn.Dropout(0.)

        # learnable scalars blending the four sparse attention branches
        self.attn1 = torch.nn.Parameter(torch.tensor([0.2]), requires_grad=True)
        self.attn2 = torch.nn.Parameter(torch.tensor([0.2]), requires_grad=True)
        self.attn3 = torch.nn.Parameter(torch.tensor([0.2]), requires_grad=True)
        self.attn4 = torch.nn.Parameter(torch.tensor([0.2]), requires_grad=True)

    def forward(self, x):
        b, c, h, w = x.shape

        qkv = self.qkv_dwconv(self.qkv(x))
        q, k, v = qkv.chunk(3, dim=1)
        skip_v = v

        q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)

        q = torch.nn.functional.normalize(q, dim=-1)
        k = torch.nn.functional.normalize(k, dim=-1)
        _, _, C, _ = q.shape

        mask1 = torch.zeros(b, self.num_heads, C, C, device=x.device, requires_grad=False)
        mask2 = torch.zeros(b, self.num_heads, C, C, device=x.device, requires_grad=False)
        mask3 = torch.zeros(b, self.num_heads, C, C, device=x.device, requires_grad=False)
        mask4 = torch.zeros(b, self.num_heads, C, C, device=x.device, requires_grad=False)

        attn = (q @ k.transpose(-2, -1)) * self.temperature

        # Four sparsity levels: keep the top 1/2, 2/3, 3/4 and 4/5 scores in each row of
        # the C x C channel-attention map; the rest are set to -inf before the softmax.
        index = torch.topk(attn, k=int(C / 2), dim=-1, largest=True)[1]
        mask1.scatter_(-1, index, 1.)
        attn1 = torch.where(mask1 > 0, attn, torch.full_like(attn, float('-inf')))

        index = torch.topk(attn, k=int(C * 2 / 3), dim=-1, largest=True)[1]
        mask2.scatter_(-1, index, 1.)
        attn2 = torch.where(mask2 > 0, attn, torch.full_like(attn, float('-inf')))

        index = torch.topk(attn, k=int(C * 3 / 4), dim=-1, largest=True)[1]
        mask3.scatter_(-1, index, 1.)
        attn3 = torch.where(mask3 > 0, attn, torch.full_like(attn, float('-inf')))

        index = torch.topk(attn, k=int(C * 4 / 5), dim=-1, largest=True)[1]
        mask4.scatter_(-1, index, 1.)
        attn4 = torch.where(mask4 > 0, attn, torch.full_like(attn, float('-inf')))

        attn1 = attn1.softmax(dim=-1)
        attn2 = attn2.softmax(dim=-1)
        attn3 = attn3.softmax(dim=-1)
        attn4 = attn4.softmax(dim=-1)

        out1 = (attn1 @ v)
        out2 = (attn2 @ v)
        out3 = (attn3 @ v)
        out4 = (attn4 @ v)

        out = out1 * self.attn1 + out2 * self.attn2 + out3 * self.attn3 + out4 * self.attn4

        out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)
        out = out + self.dwconv(skip_v)
        out = self.project_out(out)
        return out
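
# Illustrative sketch (not part of the original model): `_topk_mask_example` is a
# hypothetical helper reproducing the masking pattern used in Attention above on a toy
# tensor. Only the k largest scores per row survive the softmax; the rest get weight 0.
def _topk_mask_example():
    scores = torch.rand(1, 1, 8, 8)                          # toy (b, head, C, C) attention map
    k = 8 // 2                                               # keep the top half of each row
    index = torch.topk(scores, k=k, dim=-1, largest=True)[1]
    mask = torch.zeros_like(scores).scatter_(-1, index, 1.)
    sparse = torch.where(mask > 0, scores, torch.full_like(scores, float('-inf')))
    weights = sparse.softmax(dim=-1)
    assert int((weights > 0).sum(dim=-1).min()) == k         # exactly k non-zero weights per row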


## Sparse Transformer Block (STB)
class TransformerBlock(nn.Module):
    def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type):
        super(TransformerBlock, self).__init__()

        self.norm1 = LayerNorm(dim, LayerNorm_type)
        self.attn = Attention(dim, num_heads, bias)
        self.norm2 = LayerNorm(dim, LayerNorm_type)
        self.ffn = FeedForward(dim, ffn_expansion_factor, bias)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.ffn(self.norm2(x))
        return x


## Fusion Transformer Block
class FusionTransformerBlock(nn.Module):
    def __init__(self, dim, num_heads, ffn_expansion_factor, bias):
        super(FusionTransformerBlock, self).__init__()

        self.c = dim
        self.fs_dim = int(dim * 0.2)
        self.norm1 = nn.LayerNorm(dim + self.fs_dim)
        self.attn = FusionMDTA(dim, num_heads)
        self.norm2 = nn.LayerNorm(dim)
        self.ffn = FeedForward(dim, ffn_expansion_factor, bias)

    def forward(self, x):
        b, c, h, w = x.shape
        # layer norm over the channel dimension of the concatenated (shallow + deep) input
        x_norm = self.norm1(x.reshape(b, c, -1).transpose(-2, -1).contiguous()).transpose(-2, -1).contiguous().reshape(b, c, h, w)
        #print(x_norm.shape)
        #print(self.fs_dim)
        # the residual keeps only the deep-feature part (last `dim` channels)
        x = x[:, self.fs_dim:c, :, :] + self.attn(x_norm)
        x = x + self.ffn(self.norm2(x.reshape(b, self.c, -1).transpose(-2, -1).contiguous()).transpose(-2, -1).contiguous().reshape(b, self.c, h, w))
        return x


## StageOne Transformer Block
class StageOneBlock(nn.Module):
    def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type):
        super(StageOneBlock, self).__init__()

        fs_dim = int(dim * 0.2)
        # 1x1 conv shrinking the input to roughly 20% of its channels (shallow branch)
        self.shrink = nn.Conv2d(dim, fs_dim, kernel_size=1, bias=bias)
        self.blk1 = TransformerBlock(dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type)
        self.blk2 = FusionTransformerBlock(dim, num_heads, ffn_expansion_factor, bias)

    def forward(self, x):
        #b, c, h, w = x.shape
        x_1 = self.blk1(x)
        x_shrink = self.shrink(x)
        # fuse the shrunken input with the sparse-attention branch output
        x_2 = self.blk2(torch.cat([x_shrink, x_1], dim=1))
        return x_2


########################
####################################
######################################
## Overlapped image patch embedding with 3x3 Conv
class OverlapPatchEmbed(nn.Module):
    def __init__(self, in_c=3, embed_dim=48, bias=False):
        super(OverlapPatchEmbed, self).__init__()
        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias)

    def forward(self, x):
        x = self.proj(x)
        return x


## Resizing modules
class Downsample(nn.Module):
    def __init__(self, n_feat):
        super(Downsample, self).__init__()
        self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False),
                                  nn.PixelUnshuffle(2))

    def forward(self, x):
        return self.body(x)


class Upsample(nn.Module):
    def __init__(self, n_feat):
        super(Upsample, self).__init__()
        self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat * 2, kernel_size=3, stride=1, padding=1, bias=False),
                                  nn.PixelShuffle(2))

    def forward(self, x):
        return self.body(x)
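
# Illustrative sketch (not part of the original model): `_resize_shape_example` is a
# hypothetical helper documenting the resizing modules. Downsample halves the spatial size
# and doubles the channels via PixelUnshuffle; Upsample does the inverse via PixelShuffle.
def _resize_shape_example():
    x = torch.rand(1, 48, 64, 64)
    down = Downsample(48)(x)        # conv 48 -> 24, then unshuffle: (1, 96, 32, 32)
    assert down.shape == (1, 96, 32, 32)
    up = Upsample(96)(down)         # conv 96 -> 192, then shuffle: (1, 48, 64, 64)
    assert up.shape == (1, 48, 64, 64)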


######################
class DRSformer(nn.Module):
    def __init__(self,
                 inp_channels=3,
                 out_channels=3,
                 dim=48,
                 num_blocks=[4, 6, 6, 8],
                 heads=[1, 2, 4, 8],
                 ffn_expansion_factor=2.66,
                 bias=False,
                 LayerNorm_type='WithBias'  ## Other option 'BiasFree'
                 ):
        super(DRSformer, self).__init__()

        self.patch_embed = OverlapPatchEmbed(inp_channels, dim)

        #self.encoder_level0 = subnet(dim)  ## We do not use MEFC for training Rain200L and SPA-Data

        self.encoder_level1 = nn.Sequential(*[
            StageOneBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor,
                          bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])])

        self.down1_2 = Downsample(dim)  ## From Level 1 to Level 2
        self.encoder_level2 = nn.Sequential(*[
            TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor,
                             bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])])

        self.down2_3 = Downsample(int(dim * 2 ** 1))  ## From Level 2 to Level 3
        self.encoder_level3 = nn.Sequential(*[
            TransformerBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor,
                             bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])])

        self.down3_4 = Downsample(int(dim * 2 ** 2))  ## From Level 3 to Level 4
        self.latent = nn.Sequential(*[
            TransformerBlock(dim=int(dim * 2 ** 3), num_heads=heads[3], ffn_expansion_factor=ffn_expansion_factor,
                             bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[3])])

        self.up4_3 = Upsample(int(dim * 2 ** 3))  ## From Level 4 to Level 3
        #self.reduce_chan_level3 = nn.Conv2d(int(dim * 2 ** 3), int(dim * 2 ** 2), kernel_size=1, bias=bias)
        self.decoder_level3 = nn.Sequential(*[
            TransformerBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor,
                             bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])])

        self.up3_2 = Upsample(int(dim * 2 ** 2))  ## From Level 3 to Level 2
        #self.reduce_chan_level2 = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=1, bias=bias)
        self.decoder_level2 = nn.Sequential(*[
            TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor,
                             bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])])

        self.up2_1 = Upsample(int(dim * 2 ** 1))  ## From Level 2 to Level 1 (NO 1x1 conv to reduce channels)

        self.decoder_level1 = nn.Sequential(*[
            StageOneBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor,
                          bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])])

        self.decoder_level0 = nn.Sequential(*[
            StageOneBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor,
                          bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])])

        ############# SKFUSION
        self.fusion1 = SKFF(192)
        self.fusion2 = SKFF(96)
        ############
        #self.conv1 = nn.Conv2d(384, 192, kernel_size=1, bias=False)
        #self.conv2 = nn.Conv2d(192, 96, kernel_size=1, bias=False)

        #self.refinement = subnet(dim=int(dim * 2 ** 1))  ## We do not use MEFC for training Rain200L and SPA-Data

        self.output = nn.Conv2d(int(dim * 2 ** 1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias)

    def forward(self, inp_img):
        inp_enc_level1 = self.patch_embed(inp_img)
        #inp_enc_level0 = self.encoder_level0(inp_enc_level1)  ## We do not use MEFC for training Rain200L and SPA-Data
        out_enc_level1 = self.encoder_level1(inp_enc_level1)

        inp_enc_level2 = self.down1_2(out_enc_level1)
        out_enc_level2 = self.encoder_level2(inp_enc_level2)

        inp_enc_level3 = self.down2_3(out_enc_level2)
        out_enc_level3 = self.encoder_level3(inp_enc_level3)

        inp_enc_level4 = self.down3_4(out_enc_level3)
        latent = self.latent(inp_enc_level4)

        inp_dec_level3 = self.up4_3(latent)
        inp_dec_level3 = self.fusion1([out_enc_level3, inp_dec_level3])
        #inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 1)
        #inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3)
        out_dec_level3 = self.decoder_level3(inp_dec_level3)

        inp_dec_level2 = self.up3_2(out_dec_level3)
        inp_dec_level2 = self.fusion2([out_enc_level2, inp_dec_level2])
        #inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 1)
        #inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2)
        out_dec_level2 = self.decoder_level2(inp_dec_level2)

        inp_dec_level1 = self.up2_1(out_dec_level2)
        inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 1)
        out_dec_level1 = self.decoder_level1(inp_dec_level1)
        out_dec_level0 = self.decoder_level0(out_dec_level1)

        #out_dec_level1 = self.refinement(out_dec_level1)  ## We do not use MEFC for training Rain200L and SPA-Data

        out_dec_total = self.output(out_dec_level0) + inp_img

        return out_dec_total


#######################
if __name__ == '__main__':
    input = torch.rand(1, 3, 256, 256)
    model = DRSformer()
    # output = model(input)
    from fvcore.nn import FlopCountAnalysis, parameter_count_table
    flops = FlopCountAnalysis(model, input)
    print("FLOPs: ", flops.total())
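
    # Hedged extra check (illustrative, not in the original script): parameter_count_table
    # is already imported from fvcore above and only walks the module tree, so it does not
    # require a forward pass through the external BlockdiagLinear layers.
    print(parameter_count_table(model))
    # If a forward pass is available, the residual design implies the output keeps the
    # input resolution:
    # assert model(input).shape == input.shape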