Initial attempt to chop up long inputs to a transformer into pieces that the transformer can digest, even if it isn't necessarily going to give great results for the later tokens in the sentence. Addresses #1294
AngledLuffa committed Feb 2, 2024
1 parent 4b7c6b4 commit bbe90ee
Showing 1 changed file with 44 additions and 12 deletions.
stanza/models/common/bert_embedding.py: 44 additions & 12 deletions
@@ -341,6 +341,43 @@ def extract_xlnet_embeddings(model_name, tokenizer, model, data, device, keep_en
 
     return processed
 
+def build_cloned_features(model, tokenizer, attention_tensor, id_tensor, num_layers, detach, device):
+    """
+    Extract an embedding from the given transformer for a particular attention mask and token range
+    If the tokens are longer than the max length supported by the
+    model, the range is split into overlapping sections and the
+    overlapping pieces are stitched back together.  There is no
+    guarantee this gives good results for the later tokens, but at
+    least it returns something instead of failing outright.
+    """
+    if attention_tensor.shape[1] <= tokenizer.model_max_length:
+        features = model(id_tensor, attention_mask=attention_tensor, output_hidden_states=True)
+        features = cloned_feature(features.hidden_states, num_layers, detach)
+        return features
+
+    slices = []
+    # each window advances by slice_len tokens and overlaps the
+    # previous window by prefix_len tokens
+    slice_len = max(tokenizer.model_max_length - 20, tokenizer.model_max_length // 2)
+    prefix_len = tokenizer.model_max_length - slice_len
+    if slice_len < 5:
+        raise RuntimeError("Really tiny tokenizer!")
+    remaining_attention = attention_tensor
+    remaining_ids = id_tensor
+    while True:
+        attention_slice = remaining_attention[:, :tokenizer.model_max_length]
+        id_slice = remaining_ids[:, :tokenizer.model_max_length]
+        features = model(id_slice, attention_mask=attention_slice, output_hidden_states=True)
+        features = cloned_feature(features.hidden_states, num_layers, detach)
+        if len(slices) > 0:
+            # drop the overlapping prefix so each position is kept exactly once
+            features = features[:, prefix_len:, :]
+        slices.append(features)
+        if remaining_attention.shape[1] <= tokenizer.model_max_length:
+            break
+        remaining_attention = remaining_attention[:, slice_len:]
+        remaining_ids = remaining_ids[:, slice_len:]
+    slices = torch.cat(slices, axis=1)
+    return slices
+
+
 def extract_bert_embeddings(model_name, tokenizer, model, data, device, keep_endpoints, num_layers=None, detach=True):
     """
@@ -378,27 +415,22 @@ def extract_bert_embeddings(model_name, tokenizer, model, data, device, keep_end
             list_offsets[idx][offset+1] = pos
         list_offsets[idx][0] = 0
         list_offsets[idx][-1] = list_offsets[idx][-2] + 1
-        #print(list_offsets[idx])
         if any(x is None for x in list_offsets[idx]):
             raise ValueError("OOPS, hit None when preparing to use Bert\ndata[idx]: {}\noffsets: {}\nlist_offsets[idx]: {}".format(data[idx], offsets, list_offsets[idx], tokenized))
 
-        if list_offsets[idx][-1] > tokenizer.model_max_length - 1:
-            logger.error("Invalid size, max size: %d, got %d.\nTokens: %s\nTokenized: %s", tokenizer.model_max_length, len(offsets), data[idx][:1000], offsets[:1000])
-            raise TextTooLongError(len(offsets), tokenizer.model_max_length, idx, " ".join(data[idx]))
+        #if list_offsets[idx][-1] > tokenizer.model_max_length - 1:
+        #    logger.error("Invalid size, max size: %d, got %d.\nTokens: %s\nTokenized: %s", tokenizer.model_max_length, len(offsets), data[idx][:1000], offsets[:1000])
+        #    raise TextTooLongError(len(offsets), tokenizer.model_max_length, idx, " ".join(data[idx]))
 
     features = []
     for i in range(int(math.ceil(len(data)/128))):
+        attention_tensor = torch.tensor(tokenized['attention_mask'][128*i:128*i+128], device=device)
+        id_tensor = torch.tensor(tokenized['input_ids'][128*i:128*i+128], device=device)
         if detach:
             with torch.no_grad():
-                attention_mask = torch.tensor(tokenized['attention_mask'][128*i:128*i+128], device=device)
-                id_tensor = torch.tensor(tokenized['input_ids'][128*i:128*i+128], device=device)
-                feature = model(id_tensor, attention_mask=attention_mask, output_hidden_states=True)
-                features += cloned_feature(feature.hidden_states, num_layers, detach)
+                features += build_cloned_features(model, tokenizer, attention_tensor, id_tensor, num_layers, detach, device)
         else:
-            attention_mask = torch.tensor(tokenized['attention_mask'][128*i:128*i+128], device=device)
-            id_tensor = torch.tensor(tokenized['input_ids'][128*i:128*i+128], device=device)
-            feature = model(id_tensor, attention_mask=attention_mask, output_hidden_states=True)
-            features += cloned_feature(feature.hidden_states, num_layers, detach)
+            features += build_cloned_features(model, tokenizer, attention_tensor, id_tensor, num_layers, detach, device)
 
     processed = []
     #process the output
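For context, this is roughly how the patched extract_bert_embeddings might be exercised on an over-long sentence. A hedged sketch, not taken from the repository: the choice of bert-base-cased, the synthetic sentence, and the expectation about the returned list are assumptions based on the code above.

import torch
from transformers import AutoTokenizer, AutoModel

from stanza.models.common.bert_embedding import extract_bert_embeddings

model_name = "bert-base-cased"          # hypothetical transformer choice
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# one sentence of 1200 "words", far beyond the usual 512 subword limit
data = [["word%d" % i for i in range(1200)]]
features = extract_bert_embeddings(model_name, tokenizer, model, data, device,
                                   keep_endpoints=False)
print(len(features))                    # expected: one feature tensor per input sentence

Before this commit, an input this long hit the now-commented-out length check and raised TextTooLongError; with build_cloned_features it is instead split into overlapping windows, with the caveat from the commit message that quality for the later tokens is unverified.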
