I wonder how you visualized attention in the final ViT.
My understanding is:
Given x, y = forward_encoder(x), then y2 = forward_decoder(y). In this step, did you use x1 or x2 (from the snippet below)?
# apply Transformer blocks
for blk in self.decoder_blocks:
    x = blk(x)
x1 = self.decoder_norm(x)

# predictor projection
x2 = self.decoder_pred(x1)
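To make the question concrete, my reading of the two candidates is below. The shapes are my own assumption for a 16-frame 224x224 clip with t_grid_size = 8 and grid_size = 14 (so 8 * 14 * 14 = 1568 tokens plus a cls token); I have not verified them against the repo.

# my assumed shapes (not taken from the repo):
# x1 = self.decoder_norm(x)    # [N, 1 + 1568, decoder_embed_dim]  -- decoder features
# x2 = self.decoder_pred(x1)   # [N, 1 + 1568, t_patch * p * p * 3]  -- per-patch pixel predictions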
---code---
import os

import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F

# show_image and dest_attn_dir are defined elsewhere in my script

def vis_attention(idx, batch, model):
    x = batch["image"]
    x = x.unsqueeze(0)  # n, c, t, h, w
    attention_last = model.get_last_selfattention(x.cuda())
    nh = 16  # number of decoder heads
    ch = attention_last.shape[2]
    dim = ch // nh
    attention_last = attention_last.view(nh, 1568, dim)
    attention_head_feature = attention_last[0]
    attention_head_feature = attention_head_feature.view(14, 14, 8, dim)
    # sum along the last dimension to get the attention map
    attention_map = attention_head_feature.sum(dim=-1)
    # normalize attention map to [0, 1]
    attention_map = (attention_map - attention_map.min()) / (attention_map.max() - attention_map.min())
    # attention_map = attention_map * torch.int((attention_map > 0.5))
    # visualize attention map for each frame
    for frame_idx in range(8):
        frame_attention = attention_map[:, :, frame_idx]
        resized_attention_map = F.interpolate(
            frame_attention.unsqueeze(0).unsqueeze(0), size=(224, 224), mode="bilinear"
        )[0][0]
        plt.subplot(1, 3, 1)
        vis_img = torch.einsum('chw->hwc', x[0, :, frame_idx, :, :])
        show_image(vis_img, "original")
        plt.subplot(1, 3, 2)
        plt.imshow(resized_attention_map.cpu().detach().numpy(), cmap='hot', interpolation='bilinear')
        plt.title(f'Attention Map - Frame {frame_idx + 1}')
        plt.subplot(1, 3, 3)
        plt.imshow(vis_img[:, :, 0] * resized_attention_map.cpu().detach().numpy(), cmap='hot', interpolation='bilinear')
        plt.savefig(os.path.join(dest_attn_dir, str(idx) + "_" + str(frame_idx)))
        plt.show()
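For completeness, this is roughly how I call it; `val_loader` and the rest of the setup are placeholders from my own script, not from this repo.

# hypothetical driver loop; `val_loader`, `model`, and `dest_attn_dir` come from my own setup
model.eval()
for idx, batch in enumerate(val_loader):
    with torch.no_grad():
        vis_attention(idx, batch, model)
    if idx == 4:  # only dump a few clips
        break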
--
def prepare_tokens(self, x):
    latent, mask, ids_restore = self.forward_encoder(x, 0.75)
    pred = self.forward_decoder_get_last_attn(latent, ids_restore)  # [N, L, p*p*3]
    return pred, mask
def get_last_selfattention(self, x):
    x, mask = self.prepare_tokens(x)
    return x
def forward_decoder_get_last_attn(self, x, ids_restore):
    N = x.shape[0]
    T = self.patch_embed.t_grid_size
    H = W = self.patch_embed.grid_size

    # embed tokens
    x = self.decoder_embed(x)
    C = x.shape[-1]

    # append mask tokens to sequence
    mask_tokens = self.mask_token.repeat(N, T * H * W + 0 - x.shape[1], 1)
    x_ = torch.cat([x[:, :, :], mask_tokens], dim=1)  # no cls token
    x_ = x_.view([N, T * H * W, C])
    x_ = torch.gather(
        x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x_.shape[2])
    )  # unshuffle
    x = x_.view([N, T * H * W, C])

    # append cls token
    if self.cls_embed:
        decoder_cls_token = self.decoder_cls_token
        decoder_cls_tokens = decoder_cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((decoder_cls_tokens, x), dim=1)

    if self.sep_pos_embed:
        decoder_pos_embed = self.decoder_pos_embed_spatial.repeat(
            1, self.input_size[0], 1
        ) + torch.repeat_interleave(
            self.decoder_pos_embed_temporal,
            self.input_size[1] * self.input_size[2],
            dim=1,
        )
        if self.cls_embed:
            decoder_pos_embed = torch.cat(
                [
                    self.decoder_pos_embed_class.expand(
                        decoder_pos_embed.shape[0], -1, -1
                    ),
                    decoder_pos_embed,
                ],
                1,
            )
    else:
        decoder_pos_embed = self.decoder_pos_embed[:, :, :]

    # add pos embed
    x = x + decoder_pos_embed

    attn = self.decoder_blocks[0].attn
    requires_t_shape = hasattr(attn, "requires_t_shape") and attn.requires_t_shape
    if requires_t_shape:
        x = x.view([N, T, H * W, C])

    # apply Transformer blocks
    for blk in self.decoder_blocks:
        x = blk(x)
    x = self.decoder_norm(x)
    x = x[:, 1:, :]  # remove cls token
    return x
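In case it helps clarify what I am after: since the decoder blocks do not return attention probabilities directly, I also tried grabbing them with a forward hook. This is only a sketch and assumes a timm-style Attention module where an nn.Dropout named attn_drop is applied to the softmaxed [N, num_heads, L, L] attention matrix; if this repo's decoder attention is implemented differently, the hooked module would need to change.

# hedged sketch: capture the last decoder block's attention probabilities with a forward hook.
# Assumes `decoder_blocks[-1].attn.attn_drop` is an nn.Dropout applied to the softmaxed
# attention matrix (timm-style); adjust the hook target if the implementation differs.
attn_maps = []

def _grab_attn(module, inputs, output):
    # under the assumption above, inputs[0] is the softmaxed attention before dropout
    attn_maps.append(inputs[0].detach().cpu())

handle = model.decoder_blocks[-1].attn.attn_drop.register_forward_hook(_grab_attn)
with torch.no_grad():
    model.get_last_selfattention(x.cuda())
handle.remove()

attn = attn_maps[-1]  # [N, num_heads, L, L] if the assumption holds

From that tensor, the cls-token row (or a mean over query tokens) could be reshaped to [T, H, W] and overlaid per frame, similar to vis_attention above. Is that closer to what you did?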
Thank you.