Further development of attention maps; no weight decay for 1D parameters (#45)

matsen authored Jul 23, 2024
1 parent d02dd88 commit c6c52e8
Showing 2 changed files with 56 additions and 34 deletions.

netam/attention_map.py: 43 additions, 33 deletions

@@ -8,28 +8,43 @@
 So this tells us that the rows of the attention map correspond to the queries, whereas the columns correspond to the keys.
 In our terminology, an attention map is the attention map for a single head. An
-"attention maps" object is a collection of attention maps, a tensor where the
-first dimension is the number of heads. An "attention mapss" is a list of
-attention maps objects, one for each sequence in the batch. An "attention
-profile" is some 1-D summary of an attention map, such as the maximum attention
-score for each key position.
+"attention maps" object is a collection of attention maps: a tensor where the
+first dimension is the number of heads. This assumes all layers have the same
+number of heads. An "attention mapss" is a list of attention maps objects, one
+for each sequence in the batch. An "attention profile" is some 1-D summary of an
+attention map, such as the maximum attention score for each key position.
 # Adapted from https://gist.github.com/airalcorn2/50ec06517ce96ecc143503e21fa6cb91
 """

 import copy

 import torch

 from netam.common import aa_idx_tensor_of_str_ambig, aa_mask_tensor_of


-class SaveOutput:
-    def __init__(self):
-        self.outputs = []
+def reshape_tensor(tensor, head_count):
+    """
+    Reshape the tensor to include the head dimension.
+    Assumes batch size is 1 and squeezes it out.
+    """
+    assert tensor.size(0) == 1, "Batch size should be 1"
+    seq_len, embed_dim = tensor.size(1), tensor.size(2)
+    head_dim = embed_dim // head_count
+    return tensor.view(seq_len, head_count, head_dim).transpose(0, 1)
+
+
+class SaveAttentionInfo:
+    def __init__(self, head_count):
+        self.head_count = head_count
+        self.attention_maps = []

     def __call__(self, module, module_in, module_out):
-        self.outputs.append(module_out[1])
+        self.attention_maps.append(module_out[1].clone().squeeze(0))

     def clear(self):
-        self.outputs = []
+        self.attention_maps = []


 def patch_attention(m):
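
As an aside, here is a minimal sketch of the shapes the docstring's terminology implies. This is not code from the repo; head_count and seq_len are placeholder values.

import torch

head_count, seq_len = 4, 10

# An attention map: one head's query-by-key score matrix.
attention_map = torch.rand(seq_len, seq_len)

# An "attention maps" object: heads stacked along the first dimension.
attention_maps = torch.rand(head_count, seq_len, seq_len)

# An "attention mapss": one attention maps object per sequence in the batch.
attention_mapss = [attention_maps.clone() for _ in range(3)]

# An "attention profile": a 1-D summary, e.g. the maximum score for each key
# position after averaging over heads (queries index the rows, keys the columns).
attention_profile = attention_maps.mean(dim=0).max(dim=0).values  # shape (seq_len,)
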
@@ -44,36 +44,59 @@ def wrap(*args, **kwargs):
     m.forward = wrap


-def attention_mapss_of(model, which_layer, sequences):
+def attention_mapss_of(model, sequences):
     """
-    Get a list of attention maps (across sequences) for the specified layer of
-    the model.
+    Get a list of attention maps (across sequences) as described in the module
+    docstring.
     """
     model = copy.deepcopy(model)
-    save_output = SaveOutput()
-    patch_attention(model.encoder.layers[which_layer].self_attn)
-    hook_handle = model.encoder.layers[which_layer].self_attn.register_forward_hook(
-        save_output
-    )
     model.eval()
+    layer_count = len(model.encoder.layers)
+    # The below assumes all layers have the same number of heads.
+    head_count = model.encoder.layers[0].self_attn.num_heads
+    save_info = [SaveAttentionInfo(head_count) for _ in range(layer_count)]
+    for which_layer, layer in enumerate(model.encoder.layers):
+        patch_attention(layer.self_attn)
+        layer.self_attn.register_forward_hook(save_info[which_layer])

     for sequence in sequences:
         sequence_idxs = aa_idx_tensor_of_str_ambig(sequence)
         mask = aa_mask_tensor_of(sequence)
         model(sequence_idxs.unsqueeze(0), mask.unsqueeze(0))
-    return [out[0].detach().numpy() for out in save_output.outputs]
+
+    attention_maps = []
+
+    for seq_idx in range(len(sequences)):
+        attention_maps.append(
+            torch.stack([save.attention_maps[seq_idx] for save in save_info], dim=0)
+        )
+
+    return [amap.detach().numpy() for amap in attention_maps]


-def attention_profiles_of(model, which_layer, sequences, by):
-    """
-    Take the mean attention map by heads, then take the maximum attention
-    score to get a profile indexed by `by`.
-
-    If by="key", this will return the maximum attention score for each key position.
-    If by="query", this will return the maximum attention score for each query position.
-    """
-    by_to_index_dict = {"key": 0, "query": 1}
-    assert by in by_to_index_dict, f"by must be one of {by_to_index_dict.keys()}"
-    axis = by_to_index_dict[by]
-    attention_mapss = attention_mapss_of(model, which_layer, sequences)
-    return [
-        attention_maps.mean(axis=0).max(axis=axis) for attention_maps in attention_mapss
-    ]
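
attention_profiles_of is removed by this commit, but a profile can still be computed from the output of attention_mapss_of. A minimal sketch, assuming each returned array now has shape (layer_count, head_count, seq_len, seq_len); the random arrays below are stand-ins for real model output.

import numpy as np

# Stand-in for attention_mapss_of(model, sequences): one array per sequence,
# with an assumed shape of (layer_count, head_count, seq_len, seq_len).
layer_count, head_count, seq_len = 3, 4, 10
attention_mapss = [
    np.random.rand(layer_count, head_count, seq_len, seq_len) for _ in range(2)
]

# Reproduce the removed by="key" profile for a chosen layer: average over heads,
# then take the maximum over queries (rows) for each key position (columns).
which_layer = -1
key_profiles = [maps[which_layer].mean(axis=0).max(axis=0) for maps in attention_mapss]
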

netam/framework.py: 13 additions, 1 deletion

@@ -421,9 +421,21 @@ def reset_optimization(self, learning_rate=None):
         if learning_rate is None:
             learning_rate = self.learning_rate

+        # copied from https://github.com/karpathy/nanoGPT/blob/9755682b981a45507f6eb9b11eadef8cb83cebd5/model.py#L264
+        param_dict = {
+            pn: p for pn, p in self.model.named_parameters() if p.requires_grad
+        }
+        # Do not apply weight decay to 1D parameters (biases and layernorm weights).
+        decay_params = [p for p in param_dict.values() if p.dim() >= 2]
+        nodecay_params = [p for p in param_dict.values() if p.dim() < 2]
+        optim_groups = [
+            {"params": decay_params, "weight_decay": self.weight_decay},
+            {"params": nodecay_params, "weight_decay": 0.0},
+        ]
+
         self.optimizer = optimizer_of_name(
             self.optimizer_name,
-            self.model.parameters(),
+            optim_groups,
             lr=learning_rate,
             weight_decay=self.weight_decay,
         )
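
The framework.py change replaces a plain model.parameters() argument with parameter groups, so only tensors with two or more dimensions receive weight decay. A standalone sketch of the same pattern, using torch.optim.AdamW directly rather than the repo's optimizer_of_name helper, with a placeholder model and an illustrative weight decay of 0.01:

import torch

model = torch.nn.TransformerEncoderLayer(d_model=16, nhead=2, batch_first=True)

param_dict = {pn: p for pn, p in model.named_parameters() if p.requires_grad}
# Weight matrices (dim >= 2) get weight decay; biases and layer-norm weights (1-D) do not.
decay_params = [p for p in param_dict.values() if p.dim() >= 2]
nodecay_params = [p for p in param_dict.values() if p.dim() < 2]
optim_groups = [
    {"params": decay_params, "weight_decay": 0.01},
    {"params": nodecay_params, "weight_decay": 0.0},
]
optimizer = torch.optim.AdamW(optim_groups, lr=1e-3)
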
