-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathdata_util.py
40 lines (36 loc) · 1.37 KB
/
data_util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import json
import numpy as np
import torch
import string
# Substitution table loaded once at import time. The path is resolved relative
# to the current working directory, not to this file.
# Pin the encoding so loading does not depend on the platform's locale default.
with open('counterfitted_neighbors.json', 'r', encoding='utf-8') as f:
    ws = json.load(f)
class WordSubstitude:
    """Randomly replace words in text with neighbors from a substitution table.

    ``table`` maps a word to a dict whose ``'set'`` entry is a sequence of
    candidate replacement words, e.g. ``{'good': {'set': ['good', 'great']}}``.
    Sampling uses NumPy's global random state (``np.random.randint``), so
    results are reproducible only if the caller seeds that state.

    The class name keeps its original spelling for backward compatibility.
    """

    def __init__(self, table):
        self.table = table
        # Snapshot of the table's keys for O(1) membership tests.
        # NOTE: taken once here, so words added to ``table`` later are not
        # considered by sample_from_table.
        self.table_key = set(table)
        # A token ending in one of these characters has that single character
        # split off before lookup and re-appended afterwards ("movie." -> "film.").
        self.exclude = set(string.punctuation)

    def get_perturbed_batch(self, batch, rep=1):
        """Return ``rep`` independently perturbed copies of ``batch``.

        Args:
            batch: sequence of rows whose first element (``row[0]``) is a
                text string.
            rep: number of perturbed copies to produce for each text.

        Returns:
            A NumPy string array of shape ``(rep * len(batch), 1)`` (shape
            ``(0,)`` when there is nothing to return). Row
            ``k * len(batch) + i`` is the ``k``-th perturbation of
            ``batch[i][0]``.
        """
        out_batch = []
        for _ in range(rep):
            for row in batch:
                out_batch.append([self._perturb_text(row[0])])
        return np.array(out_batch)

    def _perturb_text(self, text):
        """Substitute every space-separated token of ``text`` via the table."""
        tokens = text.split(' ')
        for j, token in enumerate(tokens):
            # split(' ') yields '' for leading, trailing or repeated spaces.
            # Keep those as-is so the original spacing survives the re-join,
            # and so token[-1] below never indexes an empty string.
            if not token:
                continue
            if token[-1] in self.exclude:
                tokens[j] = self.sample_from_table(token[:-1]) + token[-1]
            else:
                tokens[j] = self.sample_from_table(token)
        return ' '.join(tokens)

    def sample_from_table(self, word):
        """Return a uniformly random candidate for ``word``, or ``word`` itself.

        ``word`` is returned unchanged when it is not a table key or when its
        candidate list is empty; ``np.random.randint(0, 0)`` would raise in
        the latter case.
        """
        if word not in self.table_key:
            return word
        candidates = self.table[word]['set']
        # len() rather than truthiness, so a NumPy-array candidate list works.
        if len(candidates) == 0:
            return word
        return candidates[np.random.randint(0, len(candidates))]