Further development of attention maps; no weight decay for 1D parameters #45
Conversation
When I was getting surprising results from the attention maps code, I wrote the code below. Although it uses the same mechanisms as the original code, it doesn't seem to work for extracting anything other than the final attention maps: the keys, queries, and values all come out identical. The printing version further down, which is the simplest check I could imagine doing, reports the same thing.

```python
"""
We are going to get the attention weights using the
[MultiHeadAttention](https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html)
module in PyTorch. These weights are

$$
\text{softmax}\left(\frac{Q K^T}{\sqrt{d_k}}\right).
$$

So this tells us that the rows of the attention map correspond to the queries,
whereas the columns correspond to the keys.

In our terminology, an attention map is the attention map for a single head. An
"attention maps" object is a collection of attention maps: a tensor where the
first dimension is the number of heads. An "attention mapss" is a list of
attention maps objects, one for each sequence in the batch. An "attention
profile" is some 1-D summary of an attention map, such as the maximum attention
score for each key position.

# Adapted from https://gist.github.com/airalcorn2/50ec06517ce96ecc143503e21fa6cb91
"""

import copy

import torch

from netam.common import aa_idx_tensor_of_str_ambig, aa_mask_tensor_of


def reshape_tensor(tensor, head_count):
    """
    Reshape the tensor to include the head dimension.

    Assumes batch size is 1 and squeezes it out.
    """
    assert tensor.size(0) == 1, "Batch size should be 1"
    seq_len, embed_dim = tensor.size(1), tensor.size(2)
    head_dim = embed_dim // head_count
    return tensor.view(seq_len, head_count, head_dim).transpose(0, 1)


class SaveAttentionInfo:
    def __init__(self, head_count):
        self.head_count = head_count
        self.queries = []
        self.keys = []
        self.values = []
        self.attention_maps = []

    def __call__(self, module, module_in, module_out):
        # module_in is the input to the attention layer, which contains the
        # queries, keys, and values.
        # Print the max difference between module_in[0] and module_in[1].
        print("max diff", (module_in[0] - module_in[1]).abs().max())
        self.queries.append(
            reshape_tensor(module_in[0].clone().detach(), self.head_count)
        )  # Queries
        self.keys.append(
            reshape_tensor(module_in[1].clone().detach(), self.head_count)
        )  # Keys
        self.values.append(
            reshape_tensor(module_in[2].clone().detach(), self.head_count)
        )  # Values
        self.attention_maps.append(module_out[1].clone().squeeze(0))  # Attention maps

    def clear(self):
        self.keys = []
        self.values = []
        self.queries = []
        self.attention_maps = []


def patch_attention(m):
    forward_orig = m.forward

    def wrap(*args, **kwargs):
        kwargs["need_weights"] = True
        kwargs["average_attn_weights"] = False
        return forward_orig(*args, **kwargs)

    m.forward = wrap


def attention_infos_of(model, sequences):
    """
    Get a list of attention maps (across sequences) for each layer of the
    model, along with the keys, queries, and values.
    """
    model = copy.deepcopy(model)
    model.eval()
    layer_count = len(model.encoder.layers)
    # Assuming all layers have the same number of heads.
    head_count = model.encoder.layers[0].self_attn.num_heads
    save_info = [SaveAttentionInfo(head_count) for _ in range(layer_count)]
    for which_layer, layer in enumerate(model.encoder.layers):
        patch_attention(layer.self_attn)
        layer.self_attn.register_forward_hook(save_info[which_layer])

    for sequence in sequences:
        sequence_idxs = aa_idx_tensor_of_str_ambig(sequence)
        mask = aa_mask_tensor_of(sequence)
        model(sequence_idxs.unsqueeze(0), mask.unsqueeze(0))

    attention_maps = []
    queries = []
    keys = []
    values = []
    for seq_idx in range(len(sequences)):
        attention_maps.append(
            torch.stack([save.attention_maps[seq_idx] for save in save_info], dim=0)
        )
        queries.append(
            torch.stack([save.queries[seq_idx] for save in save_info], dim=0)
        )
        keys.append(torch.stack([save.keys[seq_idx] for save in save_info], dim=0))
        values.append(torch.stack([save.values[seq_idx] for save in save_info], dim=0))

    return (
        [amap.detach().numpy() for amap in attention_maps],
        [query.detach().numpy() for query in queries],
        [key.detach().numpy() for key in keys],
        [value.detach().numpy() for value in values],
    )


class PrintAttentionInfo:
    def __call__(self, module, module_in, module_out):
        # module_in is the input to the attention layer, which contains the
        # queries, keys, and values.
        print("Shapes of module_in: ", module_in[0].shape, module_in[1].shape)
        print("module_in[0]:", module_in[0])
        print("module_in[1]:", module_in[1])
        print("max diff", (module_in[0] - module_in[1]).abs().max())


def print_attention_info(model, sequences):
    model = copy.deepcopy(model)
    model.eval()
    print_info = PrintAttentionInfo()
    for which_layer, layer in enumerate(model.encoder.layers):
        layer.self_attn.register_forward_hook(print_info)

    for sequence in sequences:
        sequence_idxs = aa_idx_tensor_of_str_ambig(sequence)
        mask = aa_mask_tensor_of(sequence)
        model(sequence_idxs.unsqueeze(0), mask.unsqueeze(0))


def attention_profiles_of(model, which_layer, sequences, by):
    """
    Take the mean attention map over heads, then take the maximum attention
    score to get a profile indexed by `by`.

    If by="query", this will return the maximum attention score for each query position.
    If by="key", this will return the maximum attention score for each key position.
    """
    by_to_index_dict = {"query": 1, "key": 0}
    assert by in by_to_index_dict, f"by must be one of {by_to_index_dict.keys()}"
    axis = by_to_index_dict[by]
    # attention_mapss_of is defined elsewhere (not in this snippet).
    attention_mapss = attention_mapss_of(model, which_layer, sequences)
    return [
        attention_maps.mean(axis=0).max(axis=axis) for attention_maps in attention_mapss
    ]
```
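As a usage note, `attention_infos_of` as written above returns one entry per input sequence for each of the four quantities. A minimal sketch of calling it, where `model` is assumed to be a netam encoder-style model and `sequences` a list of amino-acid strings as in the functions above:

```python
# Hypothetical usage of attention_infos_of from the listing above.
attention_maps_list, queries_list, keys_list, values_list = attention_infos_of(
    model, sequences
)
# Per sequence, the attention maps are (layer_count, head_count, seq_len, seq_len) ...
print(attention_maps_list[0].shape)
# ... while the queries/keys/values are (layer_count, head_count, seq_len, head_dim).
print(queries_list[0].shape)
```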
To check that the attention layers looked reasonable, I also ran this code in a notebook:

```python
for which_layer, layer in enumerate(model.encoder.layers):
    print(f"Layer {which_layer} - Self Attention Parameters:")
    for name, param in layer.self_attn.named_parameters():
        if len(param.data.shape) == 1:
            print(name, param.data[:5])
        else:
            print(name, param.data[:5, :5])
```
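The 1-D vs. 2-D check here is the same distinction behind the "no weight decay for 1D parameters" part of this PR: biases and LayerNorm weights are 1-D, while projection matrices are 2-D. For illustration, a minimal sketch of how the optimizer groups could be set up (the hyperparameter values are placeholders, not necessarily what netam uses):

```python
import torch


def build_optimizer(model, lr=1e-3, weight_decay=1e-2):
    # Apply weight decay only to parameters with two or more dimensions
    # (projection matrices, embeddings); leave 1-D parameters (biases,
    # LayerNorm weights) undecayed.
    decay, no_decay = [], []
    for param in model.parameters():
        if not param.requires_grad:
            continue
        (no_decay if param.dim() == 1 else decay).append(param)
    return torch.optim.AdamW(
        [
            {"params": decay, "weight_decay": weight_decay},
            {"params": no_decay, "weight_decay": 0.0},
        ],
        lr=lr,
    )
```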
Here's some code I use to plot the keys, queries, and values:

```python
import matplotlib.pyplot as plt


def plot_an_info(info, info_name):
    # info has shape (layer_count, head_count, seq_len, head_dim).
    layer_count, head_count, _, _ = info.shape
    # Summarize each position by its mean over the head dimension.
    info_summary = info.mean(axis=-1)
    fig, axs = plt.subplots(
        layer_count, head_count, figsize=(3 * head_count, 3 * layer_count)
    )
    for layer_idx, layer_info in enumerate(info_summary):
        for head_idx, head_info in enumerate(layer_info):
            axs[layer_idx, head_idx].plot(head_info)
    for ax, layer_idx in zip(axs[:, 0], range(layer_count)):
        ax.set_ylabel(f"Layer {layer_idx}")
    for ax, head_idx in zip(axs[0], range(head_count)):
        ax.set_title(f"Head {head_idx}")
    # model_name is defined in the notebook.
    fig.suptitle(f"{info_name} for {model_name}", fontsize=16)
    plt.tight_layout()
    return fig


_ = plot_an_info(keys, "keys")
```
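Here `keys` is the array for a single sequence, so continuing the hypothetical usage sketch above it would be prepared along these lines (`model_name` is just a label assumed to be defined in the notebook):

```python
model_name = "example model"  # hypothetical label used in fig.suptitle above
keys = keys_list[0]  # per-sequence array: (layer_count, head_count, seq_len, head_dim)
queries = queries_list[0]
values = values_list[0]
_ = plot_an_info(queries, "queries")
_ = plot_an_info(values, "values")
```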
... and here is the code that first pointed out that there was something funny about the non-attention-map infos:

```python
layer = 1
head = 1


def span(tensor):
    return tensor.max() - tensor.min()


span(attention_maps[layer, head]), span(queries[layer, head]), span(keys[layer, head]), span(values[layer, head])
queries[layer, head] - keys[layer, head], queries[layer, head] - values[layer, head]
```
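For what it's worth, one possible explanation for the identical non-attention-map infos (hedged, since I haven't traced the netam model itself): `nn.TransformerEncoderLayer` calls its attention as `self.self_attn(x, x, x, ...)`, so a forward hook on `MultiheadAttention` sees the same pre-projection tensor three times, while the actual queries, keys, and values are computed inside the module from `in_proj_weight` / `in_proj_bias`. A minimal sketch of recovering them from the hooked input, assuming the default packed projection and a `(1, seq_len, embed_dim)` input as expected by `reshape_tensor`:

```python
import torch.nn.functional as F


def projected_qkv(attn, x, head_count):
    # attn: the nn.MultiheadAttention module; x: the tensor the hook receives
    # as module_in[0], shape (1, seq_len, embed_dim).
    # Assumes embed_dim == kdim == vdim, so the packed in_proj_weight is used.
    q, k, v = F.linear(x, attn.in_proj_weight, attn.in_proj_bias).chunk(3, dim=-1)
    # reshape_tensor is the helper from the first listing above.
    return tuple(reshape_tensor(t.detach(), head_count) for t in (q, k, v))
```

With tensors projected this way, the `span` comparisons above would be expected to show genuinely different queries, keys, and values.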
Here is a link to some code I was using to look at attention maps on the SAbDab data: https://github.com/matsengrp/dnsm-experiments-1/pull/1#issuecomment-2231338349