-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodeling_xlmr.py
328 lines (303 loc) · 13.1 KB
/
modeling_xlmr.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
from torch._C import Value
import torch.nn as nn
import torch
import os
from transformers import AutoModelForMaskedLM
from transformers.modeling_outputs import MaskedLMOutput, SequenceClassifierOutput
from torch.nn import CrossEntropyLoss
import logging
import random
logger = logging.getLogger(__name__)
class PromptXLMR(nn.Module):
    """Prompt-based classifier built on top of a masked-LM checkpoint.

    The wrapped model is loaded with ``AutoModelForMaskedLM`` and must expose
    ``.roberta`` (the encoder) and ``.lm_head``.  ``forward`` runs the encoder,
    scores the vocabulary with the LM head and lets ``prompt_helper`` turn the
    MLM scores into sequence-classification logits.

    Optional features:

    * soft prompts: ``prompt_length`` trainable vectors inserted right after
      the first token of every sequence;
    * "hidden" mixup: during training a third, synthetic example is built by
      interpolating the ``<mask>`` encodings of the two examples in the batch.
    """

    def __init__(
        self,
        model_args,
        config,
        prompt_helper,
        use_soft_prompt: bool,
        prompt_length: int,
        tune_LM: bool,
        multi_lingual_optim: bool,
        multi_lingual_label_word: bool,
        mixup_strategy: str,
        mixup_alpha: float,
    ):
        """Load the masked-LM checkpoint and set up the optional soft prompt.

        Args:
            model_args: Object providing ``model_name_or_path``, ``cache_dir``,
                ``model_revision`` and ``use_auth_token``.
            config: Model configuration handed to ``from_pretrained``; its
                ``vocab_size`` and ``use_return_dict`` are read in ``forward``.
            prompt_helper: Project helper used to convert labels / MLM scores
                and to look up ``tokenizer.mask_token_id``.
            use_soft_prompt: Whether to create and use a trainable soft prompt.
            prompt_length: Number of soft-prompt vectors.
            tune_LM: Only consulted when ``use_soft_prompt`` is true; if false
                every parameter of the masked-LM is frozen.
            multi_lingual_optim: Average the loss over the label sets returned
                by ``prompt_helper.convert_en_mlm_label_to_all``.
            multi_lingual_label_word: Selects which ``prompt_helper``
                conversion produces the classification logits.
            mixup_strategy: ``"hidden"`` enables hidden-state mixup while
                training; ``"input_embedding"`` is not implemented.
            mixup_alpha: Both concentration parameters of the Beta
                distribution the mixup coefficient is drawn from.
        """
        super().__init__()
        self.prompt_helper = prompt_helper
        self.config = config
        self.mlm_model = AutoModelForMaskedLM.from_pretrained(
            model_args.model_name_or_path,
            from_tf=bool(".ckpt" in model_args.model_name_or_path),
            config=config,
            cache_dir=model_args.cache_dir,
            revision=model_args.model_revision,
            use_auth_token=True if model_args.use_auth_token else None,
        )
        self.use_soft_prompt = use_soft_prompt
        self.prompt_length = prompt_length
        self.multi_lingual_optim = multi_lingual_optim
        self.multi_lingual_label_word = multi_lingual_label_word
        self.mixup_strategy = mixup_strategy
        self.mixup_alpha = mixup_alpha

        if use_soft_prompt:
            if not tune_LM:
                # Prompt tuning only: freeze the language model itself.
                for param in self.mlm_model.parameters():
                    param.requires_grad = False
            self.soft_prompt = self._init_soft_prompt(prompt_length)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        augment=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""Run the encoder and LM head and compute the (optional) MLM loss.

        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for computing the masked language modeling loss. Tokens with
            indices set to ``-100`` are ignored (masked); the loss is only
            computed for the remaining positions.
        augment:
            Accepted for interface compatibility; not used by this method.

        The remaining arguments are forwarded unchanged to
        ``self.mlm_model.roberta``.

        Returns:
            If ``return_dict`` resolves to false, a tuple
            ``(loss?, prediction_scores, *outputs[2:])`` holding the raw MLM
            scores.  Otherwise a ``SequenceClassifierOutput`` whose ``logits``
            are the MLM scores converted by ``prompt_helper``.

        Raises:
            NotImplementedError: when training with
                ``mixup_strategy == "input_embedding"``.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # ``input_ids`` is cleared below when soft prompts are used, but hidden
        # mixup still needs the token ids to locate the <mask> positions.
        mask_source_ids = input_ids

        if self.use_soft_prompt:
            inputs_embeds = self._concat_soft_prompt_to_inputs(input_ids)
            attention_mask, labels = self._extend_accordingly(attention_mask, labels)
            # set input_ids to None to disable its functionality in forwarding
            # also enabling input_embeds. refer to the api of modeling_roberta.py
            input_ids = None

        if self.training and self.mixup_strategy == "input_embedding":
            raise NotImplementedError(
                "Input embedding mixup is not implemented yet! How to deal with the position of the <mask> token?"
            )

        outputs = self.mlm_model.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]

        use_hidden_mixup = self.training and self.mixup_strategy == "hidden"
        if use_hidden_mixup:
            # Appends a third, interpolated example to the batch.
            sequence_output = self.mixup_input(mask_source_ids, sequence_output)
        prediction_scores = self.mlm_model.lm_head(sequence_output)

        masked_lm_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            if self.multi_lingual_optim:
                if use_hidden_mixup:
                    masked_lm_loss = self.mixup_loss_multilingual(
                        prediction_scores, labels, loss_fct
                    )
                else:
                    all_lan_labels = self.prompt_helper.convert_en_mlm_label_to_all(
                        labels
                    )
                    flat_scores = prediction_scores.view(-1, self.config.vocab_size)
                    # One loss per label set, averaged with equal weight.
                    masked_lm_loss = torch.stack(
                        [
                            loss_fct(flat_scores, lan_labels.view(-1))
                            for lan_labels in all_lan_labels
                        ]
                    ).mean()
            elif use_hidden_mixup:
                masked_lm_loss = self.mixup_loss_monolingual(
                    prediction_scores, labels, loss_fct
                )
            else:
                masked_lm_loss = loss_fct(
                    prediction_scores.view(-1, self.config.vocab_size),
                    labels.view(-1),
                )

        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return (
                ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
            )

        if self.multi_lingual_label_word:
            prediction_scores = self.prompt_helper.convert_mlm_prediction_scores_to_seqcls_maximum(
                prediction_scores, labels
            )
        else:
            prediction_scores = self.prompt_helper.convert_mlm_prediction_scores_to_seqcls(
                prediction_scores, labels
            )
        return SequenceClassifierOutput(
            loss=masked_lm_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def load_state(self, state_dict_path):
        """Load ``pytorch_model.bin`` from ``state_dict_path`` into this module.

        The checkpoint is mapped to CPU and loaded non-strictly; missing or
        unexpected keys are reported as warnings instead of raising.
        """
        state_dict = torch.load(
            os.path.join(state_dict_path, "pytorch_model.bin"), map_location="cpu"
        )
        load_result = self.load_state_dict(state_dict, strict=False)
        if load_result.missing_keys:
            logger.warning(
                f"There were missing keys in the checkpoint model loaded: {load_result.missing_keys}."
            )
        if load_result.unexpected_keys:
            logger.warning(
                f"There were unexpected keys in the checkpoint model loaded: {load_result.unexpected_keys}."
            )

    def _init_soft_prompt(self, prompt_length: int):
        """Return a trainable prompt initialised from random word embeddings.

        ``prompt_length`` distinct rows of the word-embedding matrix are drawn
        with ``random.sample`` and copied, so the prompt does not share storage
        with the embedding table.
        """
        word_embedding_weights = (
            self.mlm_model.roberta.embeddings.word_embeddings.weight
        )
        # Sampling from the range directly avoids building a vocabulary-sized list.
        sampled_indices = random.sample(
            range(word_embedding_weights.shape[0]), prompt_length
        )
        return nn.Parameter(word_embedding_weights[sampled_indices].clone().detach())

    def _concat_soft_prompt_to_inputs(self, input_ids):
        """Embed ``input_ids`` and insert the soft prompt after the first token.

        Returns a tensor of shape ``(batch, seq_len + prompt_len, dim)`` laid
        out as ``[first token, soft prompt, remaining tokens]``.
        """
        input_embeds = self.mlm_model.roberta.embeddings.word_embeddings(input_ids)
        batch_size = input_ids.shape[0]
        soft_prompt_embeds = self.soft_prompt.unsqueeze(0).expand(batch_size, -1, -1)
        return torch.cat(
            [input_embeds[:, :1], soft_prompt_embeds, input_embeds[:, 1:]],
            dim=1,
        )

    def _extend_accordingly(self, attention_mask: torch.Tensor, labels: torch.Tensor):
        """Pad ``attention_mask`` and ``labels`` for the inserted soft prompt.

        ``prompt_length`` columns are prepended: ones for the attention mask
        and ``-100`` (ignored by the loss) for the labels.  Either argument may
        be ``None`` and is then returned unchanged.

        NOTE(review): the prompt embeddings are inserted *after* position 0
        while the padding here is prepended; the two line up provided position
        0 is attended and carries label ``-100`` - confirm against the data
        pipeline.
        """
        if attention_mask is not None:
            # new_ones keeps the dtype and device of the existing mask.
            prompt_attention_mask = attention_mask.new_ones(
                (attention_mask.shape[0], self.prompt_length)
            )
            attention_mask = torch.cat([prompt_attention_mask, attention_mask], dim=1)
        if labels is not None:
            prompt_labels = labels.new_full(
                (labels.shape[0], self.prompt_length), -100
            )
            labels = torch.cat([prompt_labels, labels], dim=1)
        return attention_mask, labels

    def mixup_input(self, input_ids, sequence_output):
        """Append a mixed-up third example to ``sequence_output``.

        The batch must contain exactly two examples, each with exactly one
        ``<mask>`` token.  The synthetic example is a copy of example 0 whose
        ``<mask>`` encoding is ``lam * h0[mask0] + (1 - lam) * h1[mask1]`` with
        ``lam ~ Beta(mixup_alpha, mixup_alpha)``.  ``lam`` is stored on
        ``self.lam`` for the matching ``mixup_loss_*`` call.

        ``sequence_output`` may be longer than ``input_ids`` (soft prompt
        inserted after the first token); mask positions are shifted by the
        length difference.

        Raises:
            ValueError: if ``input_ids`` is ``None``.
        """
        if input_ids is None:
            raise ValueError("mixup_input needs input_ids to locate the <mask> tokens")
        beta_dist = torch.distributions.Beta(self.mixup_alpha, self.mixup_alpha)
        assert len(input_ids) == 2, "We are supposing the batch size to be 2"
        mask_pos1, mask_pos2 = torch.where(
            input_ids == self.prompt_helper.tokenizer.mask_token_id
        )[1]
        # Tokens after position 0 move right by the number of inserted vectors.
        offset = sequence_output.shape[1] - input_ids.shape[1]
        mask_pos1 = mask_pos1 + offset
        mask_pos2 = mask_pos2 + offset

        self.lam = beta_dist.sample()  # would be used in mixup_loss as well
        fake_encoding = sequence_output[0].clone()
        fake_encoding[mask_pos1] = (
            self.lam * sequence_output[0][mask_pos1]
            + (1 - self.lam) * sequence_output[1][mask_pos2]
        )
        return torch.cat([sequence_output, fake_encoding.unsqueeze(0)], dim=0)

    def mixup_loss_multilingual(self, prediction_scores, labels, loss_fct):
        """Mixup loss averaged over every label set from ``prompt_helper``.

        ``prediction_scores`` holds the two real examples followed by the
        synthetic one produced by ``mixup_input``.  For each label set the
        synthetic example is scored against example 0's label and against
        example 1's label (moved to example 0's mask position), weighted by
        ``self.lam`` and ``1 - self.lam``.  The result is
        ``2/3 * real_loss + 1/3 * fake_loss``.
        """
        vocab_size = self.config.vocab_size
        all_lan_labels = self.prompt_helper.convert_en_mlm_label_to_all(labels)
        real_scores = prediction_scores[:2].view(-1, vocab_size)
        fake_scores = prediction_scores[2].view(-1, vocab_size)

        masked_lm_loss_normal = torch.stack(
            [loss_fct(real_scores, lan_labels.view(-1)) for lan_labels in all_lan_labels]
        ).mean()

        masked_lm_loss_fake1 = torch.stack(
            [
                loss_fct(fake_scores, lan_labels[0].view(-1))
                for lan_labels in all_lan_labels
            ]
        )

        # Mask positions come from the original (first-language) labels.
        mask_pos1, mask_pos2 = torch.where(labels != -100)[1]
        fake_labels_second_component = []
        for lan_labels in all_lan_labels:
            swapped = lan_labels[0].clone()
            swapped[mask_pos1] = lan_labels[1][mask_pos2]
            fake_labels_second_component.append(swapped)

        masked_lm_loss_fake2 = torch.stack(
            [
                loss_fct(fake_scores, swapped.view(-1))
                for swapped in fake_labels_second_component
            ]
        )

        fake_loss = torch.mean(
            self.lam * masked_lm_loss_fake1 + (1 - self.lam) * masked_lm_loss_fake2
        )
        # to adjust the actual average
        return masked_lm_loss_normal * 2 / 3 + fake_loss / 3

    def mixup_loss_monolingual(self, prediction_scores, labels, loss_fct):
        """Mixup loss for a single label set.

        Same scheme as ``mixup_loss_multilingual`` without the per-language
        averaging: ``2/3 * real_loss + 1/3 * fake_loss`` where ``fake_loss``
        interpolates the synthetic example's losses with ``self.lam``.
        """
        vocab_size = self.config.vocab_size
        fake_scores = prediction_scores[2].view(-1, vocab_size)

        masked_lm_loss_normal = loss_fct(
            prediction_scores[:2].view(-1, vocab_size), labels.view(-1)
        )
        masked_lm_loss_fake1 = loss_fct(fake_scores, labels[0].view(-1))

        mask_pos1, mask_pos2 = torch.where(labels != -100)[1]
        # Example 0's label row carrying example 1's label at the mask position.
        fake_labels_second_component = labels[0].clone()
        fake_labels_second_component[mask_pos1] = labels[1][mask_pos2]
        masked_lm_loss_fake2 = loss_fct(
            fake_scores, fake_labels_second_component.view(-1)
        )

        fake_loss = (
            self.lam * masked_lm_loss_fake1 + (1 - self.lam) * masked_lm_loss_fake2
        )
        # to adjust the actual average
        return masked_lm_loss_normal * 2 / 3 + fake_loss / 3

    def mixup_input_embedding(self, input_ids):
        """Embed ``input_ids`` and append an interpolated third example.

        The two examples' word embeddings are mixed position-wise with
        ``lam ~ Beta(mixup_alpha, mixup_alpha)``; ``lam`` is stored on
        ``self.lam``.  Not used by ``forward``, which rejects the
        ``"input_embedding"`` strategy.
        """
        input_embeds = self.mlm_model.roberta.embeddings.word_embeddings(input_ids)
        beta_dist = torch.distributions.Beta(self.mixup_alpha, self.mixup_alpha)
        assert len(input_ids) == 2, "We are supposing the batch size to be 2"
        self.lam = beta_dist.sample()  # would be used in mixup_loss as well
        fake_encoding = self.lam * input_embeds[0] + (1 - self.lam) * input_embeds[1]
        return torch.cat([input_embeds, fake_encoding.unsqueeze(0)], dim=0)

    # Have a problem here, not implemented for now
    def _extend_accordingly_mixup(
        self, attention_mask: torch.Tensor, labels: torch.Tensor
    ):
        """Placeholder for extending mask/labels under input-embedding mixup.

        Raises:
            NotImplementedError: always; the extension logic was never written.
        """
        raise NotImplementedError(
            "_extend_accordingly_mixup is not implemented yet"
        )