# adapter_for_transformer.py

import logging

import torch
import torch.nn as nn

from lib.utils import l1norm

logger = logging.getLogger(__name__)


class Our_tuning_for_transformer(nn.Module):
    """Side-tuning module that fuses the hidden states of every layer of a
    transformer encoder into a single output representation."""

    def __init__(self, num_layer=13, hidden_size=None, embed_size=None,
                 downsample_D_factor=None, **kwargs):
        super(Our_tuning_for_transformer, self).__init__()
        assert num_layer == len(hidden_size), \
            'hidden_size must list one dimensionality per transformer layer'
        self.num_layer = num_layer
        mapping_size = embed_size // downsample_D_factor

        # Project each layer's hidden states into a shared low-dimensional space.
        self.downsample_layers = nn.ModuleList(
            [nn.Sequential(nn.Linear(hidden_size[i], mapping_size),
                           nn.LayerNorm(mapping_size))
             for i in range(num_layer)])
        self.projection_layer = nn.Sequential(nn.Linear(mapping_size, mapping_size),
                                              nn.LayerNorm(mapping_size))
        # Map the fused features back up to the embedding size.
        self.upsample_layer = nn.Sequential(nn.ReLU(), nn.Linear(mapping_size, embed_size))
        # Scalar gate that scores each layer's contribution to the fusion.
        self.gate_layer = nn.Linear(mapping_size, 1)

    def _compute_mask(self, max_len, valid_len):
        """Build a boolean padding mask of shape (batch, max_len) that is
        True for positions within each sequence's valid length."""
        mask = torch.arange(max_len).expand(valid_len.size(0), max_len).to(valid_len.device)
        mask = (mask < valid_len.long().unsqueeze(1))
        return mask

    def _aggregate_feature(self, features, mask=None):
        """Mean-pool token features over the sequence dimension, ignoring
        padded positions when a mask is given."""
        if mask is None:
            agg_feature = features.mean(dim=1)
        else:
            agg_mask = mask.unsqueeze(-1).float()
            features = features * agg_mask
            agg_feature = features.sum(dim=1) / agg_mask.sum(dim=1)
        return agg_feature

    def _cross_attention(self, query, context, mask=None, **kwargs):
        """Attend from query tokens to context tokens and return the
        attention-weighted context together with the attention weights.

        Assumes query and context share the same sequence length, since an
        identity matrix is added to the attention weights."""
        cross_weights = torch.matmul(query, context.permute(0, 2, 1))
        if mask is not None:
            cross_weights = cross_weights * mask.unsqueeze(1)
        # L1-normalized ReLU attention plus an identity shortcut so that
        # every token always attends to itself.
        cross_weights = l1norm(torch.relu(cross_weights), dim=-1)
        cross_weights += torch.eye(cross_weights.size(-1)).to(cross_weights.device)
        wcontext = torch.matmul(cross_weights, context)
        return wcontext, cross_weights

    def forward(self, all_hidden_features, lengths=None, **kwargs):
        assert len(all_hidden_features) == self.num_layer
        n_instance = all_hidden_features[0].size(1)
        mask = self._compute_mask(max_len=n_instance, valid_len=lengths)

        # The last layer's features act as the anchor that queries all
        # earlier layers.
        anchor_features = self.downsample_layers[-1](all_hidden_features[-1])
        anchor = self._aggregate_feature(anchor_features, mask)

        confidence_set = []
        merged_features_set = []
        for index in range(self.num_layer - 1):
            target_features = self.downsample_layers[index](all_hidden_features[index])
            target = self._aggregate_feature(target_features, mask)
            merged_features, _ = self._cross_attention(anchor_features, target_features, mask=mask)
            merged_features_set.append(merged_features)
            # Gate score: how relevant is this layer to the anchor?
            confidence_set.append(self.gate_layer(anchor * target))

        # Softmax over layers turns the gate scores into fusion weights.
        confidence_norm = torch.softmax(torch.cat(confidence_set, dim=1), dim=1)
        all_merged_features = torch.stack(merged_features_set, dim=1)
        merged_features = torch.sum(all_merged_features * confidence_norm[:, :, None, None], dim=1)

        # Residual connection around the fused features, then project back
        # to the original embedding size.
        output_features = anchor_features + self.projection_layer(merged_features)
        return self.upsample_layer(output_features), lengths


side_tuning_factory_for_transformer = {
    'our_tuning': Our_tuning_for_transformer,
}
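

# A minimal, hypothetical usage sketch (not part of the original file). It
# assumes hidden states shaped (batch, seq_len, dim) from a 12-layer
# BERT-style encoder (13 states including the embedding output), feeds
# random tensors in place of real encoder outputs, and requires
# lib.utils.l1norm to be importable. All sizes here are illustrative.
if __name__ == '__main__':
    num_layer = 13
    hidden_size = [768] * num_layer
    model = side_tuning_factory_for_transformer['our_tuning'](
        num_layer=num_layer, hidden_size=hidden_size,
        embed_size=1024, downsample_D_factor=2)

    batch, seq_len = 2, 16
    all_hidden_features = [torch.randn(batch, seq_len, 768) for _ in range(num_layer)]
    lengths = torch.tensor([16, 9])  # valid (non-padded) tokens per sample

    features, out_lengths = model(all_hidden_features, lengths=lengths)
    print(features.shape)  # expected: torch.Size([2, 16, 1024])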