visualization attention map. #187

Open
kimsekeun opened this issue Jan 16, 2024 · 0 comments
kimsekeun commented Jan 16, 2024

I wonder how you visualized attention in the final ViT.

My understanding is:

Given x,
y = forward_encoder(x)
then
y2 = forward_decoder(y). In this step, did you use x1 or X2 from the snippet below?

# apply Transformer blocks
for blk in self.decoder_blocks:
    x = blk(x)
x1 = self.decoder_norm(x)

# predictor projection
X2 = self.decoder_pred(x1)
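
For reference, the two tensors differ only in the final projection. The shapes below are illustrative assumptions (decoder_embed_dim = 512 and a 2x16x16 patch of 3 channels), not values read from this repo's config:

x1 = self.decoder_norm(x)     # [N, L, 512]         token features at the decoder width
X2 = self.decoder_pred(x1)    # [N, L, 2*16*16*3]   per-patch pixel predictions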

---code---

Main visualization function:

import os

import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F


def vis_attention(idx, batch, model):
    x = batch["image"]

    x = x.unsqueeze(0)  # [N, C, T, H, W]

    attention_last = model.get_last_selfattention(x.cuda())

    nh = 16  # decoder number of heads
    ch = attention_last.shape[2]
    dim = ch // nh

    attention_last = attention_last.view(nh, 1568, dim)  # 1568 = 14 * 14 * 8 tokens

    attention_head_feature = attention_last[0]

    attention_head_feature = attention_head_feature.view(14, 14, 8, dim)

    # sum along the last dimension to get the attention map
    attention_map = attention_head_feature.sum(dim=-1)

    # normalize attention map to [0, 1]
    attention_map = (attention_map - attention_map.min()) / (attention_map.max() - attention_map.min())

    # attention_map = attention_map * (attention_map > 0.5).float()  # optional thresholding
    # visualize attention map for each frame
    for frame_idx in range(8):
        frame_attention = attention_map[:, :, frame_idx]

        resized_attention_map = F.interpolate(
            frame_attention.unsqueeze(0).unsqueeze(0), size=(224, 224), mode="bilinear"
        )[0][0]

        plt.subplot(1, 3, 1)
        vis_img = torch.einsum('chw->hwc', x[0, :, frame_idx, :, :])
        show_image(vis_img, "original")  # show_image: plotting helper defined elsewhere

        plt.subplot(1, 3, 2)
        plt.imshow(resized_attention_map.cpu().detach().numpy(), cmap='hot', interpolation='bilinear')
        plt.title(f'Attention Map - Frame {frame_idx + 1}')

        plt.subplot(1, 3, 3)
        plt.imshow(vis_img[:, :, 0].numpy() * resized_attention_map.cpu().detach().numpy(), cmap='hot', interpolation='bilinear')

        plt.savefig(os.path.join(dest_attn_dir, str(idx) + "_" + str(frame_idx)))  # dest_attn_dir: output directory defined elsewhere
        plt.show()
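
A hypothetical driver loop for the function above; `loader` and `dest_attn_dir` are assumptions about the surrounding script, not names from this repo:

dest_attn_dir = "attn_vis"                 # assumed output directory (read as a global by vis_attention)
os.makedirs(dest_attn_dir, exist_ok=True)

model = model.cuda().eval()                # assumes the pretrained MAE model is already built and loaded
with torch.no_grad():
    for idx, batch in enumerate(loader):   # assumes each batch is a dict with an "image" tensor [C, T, H, W]
        vis_attention(idx, batch, model)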

--
def prepare_tokens(self, x):
    latent, mask, ids_restore = self.forward_encoder(x, 0.75)
    pred = self.forward_decoder_get_last_attn(latent, ids_restore)  # [N, L, p*p*3]
    return pred, mask

def get_last_selfattention(self, x):
    x, mask = self.prepare_tokens(x)
    return x

def forward_decoder_get_last_attn(self, x, ids_restore):
    N = x.shape[0]
    T = self.patch_embed.t_grid_size
    H = W = self.patch_embed.grid_size

    # embed tokens
    x = self.decoder_embed(x)
    C = x.shape[-1]

    # append mask tokens to sequence
    mask_tokens = self.mask_token.repeat(N, T * H * W + 0 - x.shape[1], 1)
    x_ = torch.cat([x[:, :, :], mask_tokens], dim=1)  # no cls token
    x_ = x_.view([N, T * H * W, C])
    x_ = torch.gather(
        x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x_.shape[2])
    )  # unshuffle
    x = x_.view([N, T * H * W, C])
    # append cls token
    if self.cls_embed:
        decoder_cls_token = self.decoder_cls_token
        decoder_cls_tokens = decoder_cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((decoder_cls_tokens, x), dim=1)

    if self.sep_pos_embed:
        decoder_pos_embed = self.decoder_pos_embed_spatial.repeat(
            1, self.input_size[0], 1
        ) + torch.repeat_interleave(
            self.decoder_pos_embed_temporal,
            self.input_size[1] * self.input_size[2],
            dim=1,
        )
        if self.cls_embed:
            decoder_pos_embed = torch.cat(
                [
                    self.decoder_pos_embed_class.expand(
                        decoder_pos_embed.shape[0], -1, -1
                    ),
                    decoder_pos_embed,
                ],
                1,
            )
    else:
        decoder_pos_embed = self.decoder_pos_embed[:, :, :]

    # add pos embed
    x = x + decoder_pos_embed

    attn = self.decoder_blocks[0].attn
    requires_t_shape = hasattr(attn, "requires_t_shape") and attn.requires_t_shape
    if requires_t_shape:
        x = x.view([N, T, H * W, C])

    # apply Transformer blocks
    for blk in self.decoder_blocks:
        x = blk(x)
    x = self.decoder_norm(x)

    x = x[:, 1:, :]
    
    return x
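
Note that, as written, forward_decoder_get_last_attn returns the normalized decoder token features, not attention matrices. If you want the actual per-head attention weights of the last decoder block, one option is to capture that block's input with a forward pre-hook and recompute the softmax scores from its attention module. The sketch below is only a minimal example under assumptions: a timm-style Attention module exposing a fused qkv Linear plus num_heads and scale attributes, tokens in [N, L, C] layout, and the get_last_selfattention entry point defined above. If this repo's Attention uses separate q/k/v projections, or the blocks see a [N, T, H*W, C] layout, the reshapes must be adapted.

captured = {}

def _grab_block_input(module, inputs):
    # forward pre-hook: stash the tokens entering the last decoder block
    captured["tokens"] = inputs[0].detach()

handle = model.decoder_blocks[-1].register_forward_pre_hook(_grab_block_input)
with torch.no_grad():
    _ = model.get_last_selfattention(x.cuda())   # runs the decoder blocks defined above
handle.remove()

attn_module = model.decoder_blocks[-1].attn
tokens = captured["tokens"]                      # assumed [N, L, C]; flatten first if it is [N, T, H*W, C]
N, L, C = tokens.shape
nh = attn_module.num_heads

# recompute q and k exactly as the attention module would (fused qkv projection assumed)
qkv = attn_module.qkv(tokens).reshape(N, L, 3, nh, C // nh).permute(2, 0, 3, 1, 4)
q, k = qkv[0], qkv[1]                            # each [N, nh, L, C // nh]

attn_weights = (q @ k.transpose(-2, -1)) * attn_module.scale
attn_weights = attn_weights.softmax(dim=-1)      # [N, nh, L, L] per-head attention weights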

Thank you.
