Further development of attention maps; no weight decay for 1D parameters #45

Merged
merged 14 commits into main from 44-attention-cont on Jul 23, 2024

Conversation

matsen commented Jul 9, 2024

No description provided.

matsen commented Jul 9, 2024

When I was getting surprising results from the attention-maps code, I wrote the code below. Although it uses the same mechanisms as the original code, it doesn't seem to work for extracting anything other than the final attention maps: the keys, queries, and values all come back identical, as also reported by the printing version further down, which is the simplest check I could imagine.

"""
We are going to get the attention weights using the [MultiheadAttention](https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html) module in PyTorch. These weights are

$$
\text{softmax}\left(\frac{Q K^T}{\sqrt{d_k}}\right).
$$

So this tells us that the rows of the attention map correspond to the queries, whereas the columns correspond to the keys.

In our terminology, an "attention map" is the attention map for a single head. An
"attention maps" object is a collection of attention maps: a tensor whose
first dimension is the number of heads. An "attention mapss" is a list of
attention maps objects, one for each sequence in the batch. An "attention
profile" is a 1-D summary of an attention map, such as the maximum attention
score for each key position.

# Adapted from https://gist.github.com/airalcorn2/50ec06517ce96ecc143503e21fa6cb91
"""

import copy

import torch

from netam.common import aa_idx_tensor_of_str_ambig, aa_mask_tensor_of


def reshape_tensor(tensor, head_count):
    """
    Reshape the tensor to include the head dimension.
    Assumes batch size is 1 and squeezes it out.
    """
    assert tensor.size(0) == 1, "Batch size should be 1"
    seq_len, embed_dim = tensor.size(1), tensor.size(2)
    head_dim = embed_dim // head_count
    return tensor.view(seq_len, head_count, head_dim).transpose(0, 1)


class SaveAttentionInfo:
    def __init__(self, head_count):
        self.head_count = head_count
        self.queries = []
        self.keys = []
        self.values = []
        self.attention_maps = []

    def __call__(self, module, module_in, module_out):
        # module_in is the input to the attention layer which contains queries, keys, and values
        # print max difference between module_in[0] and module_in[1]
        print("max diff", (module_in[0] - module_in[1]).abs().max())
        self.queries.append(reshape_tensor(module_in[0].clone().detach(), self.head_count))  # Queries
        self.keys.append(reshape_tensor(module_in[1].clone().detach(), self.head_count))     # Keys
        self.values.append(reshape_tensor(module_in[2].clone().detach(), self.head_count))   # Values
        self.attention_maps.append(module_out[1].clone().squeeze(0))  # Attention maps

    def clear(self):
        self.keys = []
        self.values = []
        self.queries = []
        self.attention_maps = []


def patch_attention(m):
    """Patch the attention module's forward so it always returns per-head attention weights."""
    forward_orig = m.forward

    def wrap(*args, **kwargs):
        kwargs["need_weights"] = True
        kwargs["average_attn_weights"] = False

        return forward_orig(*args, **kwargs)

    m.forward = wrap


def attention_infos_of(model, sequences):
    """
    Get a list of attention maps (one per sequence, stacked across layers and
    heads), along with the corresponding queries, keys, and values.
    """
    model = copy.deepcopy(model)
    model.eval()
    layer_count = len(model.encoder.layers)
    head_count = model.encoder.layers[0].self_attn.num_heads  # Assuming all layers have the same number of heads
    save_info = [SaveAttentionInfo(head_count) for _ in range(layer_count)]
    for which_layer, layer in enumerate(model.encoder.layers):
        patch_attention(layer.self_attn)
        layer.self_attn.register_forward_hook(save_info[which_layer])

    for sequence in sequences:
        sequence_idxs = aa_idx_tensor_of_str_ambig(sequence)
        mask = aa_mask_tensor_of(sequence)
        model(sequence_idxs.unsqueeze(0), mask.unsqueeze(0))

    attention_maps = []
    queries = []
    keys = []
    values = []

    for seq_idx in range(len(sequences)):
        attention_maps.append(
            torch.stack([save.attention_maps[seq_idx] for save in save_info], dim=0)
        )
        queries.append(
            torch.stack([save.queries[seq_idx] for save in save_info], dim=0)
        )
        keys.append(
            torch.stack([save.keys[seq_idx] for save in save_info], dim=0)
        )
        values.append(
            torch.stack([save.values[seq_idx] for save in save_info], dim=0)
        )

    return (
        [amap.detach().numpy() for amap in attention_maps],
        [query.detach().numpy() for query in queries],
        [key.detach().numpy() for key in keys],
        [value.detach().numpy() for value in values],
    )


class PrintAttentionInfo:
    def __call__(self, module, module_in, module_out):
        # module_in is the input to the attention layer which contains queries, keys, and values
        print("Shapes of module_in: ", module_in[0].shape, module_in[1].shape)
        print("module_in[0]:", module_in[0])
        print("module_in[1]:", module_in[1])

        print("max diff", (module_in[0] - module_in[1]).abs().max())


def print_attention_info(model, sequences):
    model = copy.deepcopy(model)
    model.eval()

    print_info = PrintAttentionInfo()

    for which_layer, layer in enumerate(model.encoder.layers):
        layer.self_attn.register_forward_hook(print_info)

    for sequence in sequences:
        sequence_idxs = aa_idx_tensor_of_str_ambig(sequence)
        mask = aa_mask_tensor_of(sequence)
        model(sequence_idxs.unsqueeze(0), mask.unsqueeze(0))


def attention_profiles_of(model, which_layer, sequences, by):
    """
    Average the attention map over heads, then take the maximum attention
    score along the remaining axis to get a profile indexed by `by`.

    If by="query", this will return the maximum attention score for each query position.
    If by="key", this will return the maximum attention score for each key position.
    """
    # The dict value is the axis reduced by the max: reducing over keys (axis 1)
    # yields a per-query profile; reducing over queries (axis 0) yields a per-key profile.
    by_to_index_dict = {"query": 1, "key": 0}
    assert by in by_to_index_dict, f"by must be one of {list(by_to_index_dict.keys())}"
    axis = by_to_index_dict[by]
    attention_mapss = attention_mapss_of(model, which_layer, sequences)
    return [
        attention_maps.mean(axis=0).max(axis=axis) for attention_maps in attention_mapss
    ]
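
Note that attention_profiles_of calls attention_mapss_of, which isn't included in the snippet above. Here is a minimal sketch of one way it could be built on top of attention_infos_of; this is a guess at the intended behavior, not necessarily the version in the repository:

def attention_mapss_of(model, which_layer, sequences):
    """
    Hypothetical helper: the attention maps for a single layer, as a list with
    one (head, query, key) array per sequence.
    """
    attention_maps, _, _, _ = attention_infos_of(model, sequences)
    return [attention_map[which_layer] for attention_map in attention_maps]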

matsen commented Jul 9, 2024

To check that the attention layers were looking reasonable, I also used this code in a notebook:

for which_layer, layer in enumerate(model.encoder.layers):
    print(f"Layer {which_layer} - Self Attention Parameters:")
    for name, param in layer.self_attn.named_parameters():
        if param.data.ndim == 1:
            # 1-D parameters (e.g. biases): print the first few entries.
            print(name, param.data[:5])
        else:
            # 2-D parameters (e.g. projection weights): print a small corner.
            print(name, param.data[:5, :5])
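
For reference, the parameters printed above for a standard torch.nn.MultiheadAttention layer include a packed in_proj_weight and in_proj_bias, and the layer's queries, keys, and values only differ from one another after the inputs pass through that projection (a forward hook's module_in captures the arguments before it). Below is a minimal sketch of applying that projection by hand; project_qkv and x are hypothetical names, and the code assumes the packed layout PyTorch uses when the query, key, and value dimensions match:

import torch.nn.functional as F


def project_qkv(attn, x):
    """
    Apply a MultiheadAttention layer's packed input projection to x.

    Assumes attn.in_proj_weight / attn.in_proj_bias exist (the default when the
    query, key, and value embedding dimensions are equal) and that x is the
    tensor passed as query, key, and value for self-attention.
    """
    w_q, w_k, w_v = attn.in_proj_weight.chunk(3, dim=0)
    b_q, b_k, b_v = attn.in_proj_bias.chunk(3, dim=0)
    return F.linear(x, w_q, b_q), F.linear(x, w_k, b_k), F.linear(x, w_v, b_v)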

matsen commented Jul 9, 2024

Here's some code I use to plot keys, queries, and values:

import matplotlib.pyplot as plt


def plot_an_info(info, info_name):
    """Plot a per-position summary of a (layer, head, seq, dim) array, one panel per head."""
    layer_count, head_count, _, _ = info.shape
    # Summarize each position by the mean over the embedding dimension.
    info_summary = info.mean(axis=-1)

    fig, axs = plt.subplots(layer_count, head_count, figsize=(3 * head_count, 3 * layer_count))

    for layer_idx, layer_info in enumerate(info_summary):
        for head_idx, head_info in enumerate(layer_info):
            axs[layer_idx, head_idx].plot(head_info)

    for ax, layer_idx in zip(axs[:, 0], range(layer_count)):
        ax.set_ylabel(f"Layer {layer_idx}")

    for ax, head_idx in zip(axs[0], range(head_count)):
        ax.set_title(f"Head {head_idx}")

    # model_name is assumed to be defined elsewhere in the notebook.
    fig.suptitle(f"{info_name} for {model_name}", fontsize=16)

    plt.tight_layout()
    return fig


_ = plot_an_info(keys, "keys")
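
One usage note: attention_infos_of above returns a list with one entry per sequence for each of the four quantities, so the keys passed to plot_an_info is presumably a single element of that list. A hypothetical call for the first sequence would be:

# Hypothetical usage, assuming `model` and `sequences` are defined as above.
attention_maps, queries, keys, values = attention_infos_of(model, sequences)
_ = plot_an_info(keys[0], "keys")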

matsen commented Jul 9, 2024

... and here is the code that first pointed out that there was something funny about the non-attention-map infos:

layer = 1
head = 1

def span(tensor):
    # Range of the values in the tensor (max minus min).
    return tensor.max() - tensor.min()

# Compare the spread of the attention map against the spreads of the queries, keys, and values.
span(attention_maps[layer, head]), span(queries[layer, head]), span(keys[layer, head]), span(values[layer, head])

# These differences show that the captured queries, keys, and values are identical.
queries[layer, head] - keys[layer, head], queries[layer, head] - values[layer, head]
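
A compact version of the same check over all layers and heads at once, assuming numpy arrays shaped (layer, head, seq, head_dim) as returned by attention_infos_of for a single sequence (hypothetical, just for illustration):

import numpy as np

# Both of these print True when the captured queries, keys, and values coincide.
print(np.allclose(queries, keys), np.allclose(queries, values))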

matsen commented Jul 16, 2024

Here is a link to some code I was using to look at attention maps on the sabdab: https://github.com/matsengrp/dnsm-experiments-1/pull/1#issuecomment-2231338349

matsen merged commit c6c52e8 into main on Jul 23, 2024 (1 check passed).
matsen deleted the 44-attention-cont branch on Jul 23, 2024.