'''
An entry (i.e., one sentence) in the input file looks like this:
SOCCER NN B-NP O
- : O O
JAPAN NNP B-NP B-LOC
GET VB B-VP O
LUCKY NNP B-NP O
WIN NNP I-NP O
, , O O
CHINA NNP B-NP B-PER
IN IN B-PP O
SURPRISE DT B-NP O
DEFEAT NN I-NP O
. . O O
Each mini-batch returns the following:
words: list of raw input sentences. ["The 26-year-old ...", ...]
x: encoded input sentences. [N, T]. int64.
is_heads: list of head markers. [[1, 1, 0, ...], [...]]
tags: list of tag strings. ['O O B-MISC ...', '...']
y: encoded tags. [N, T]. int64.
seqlens: list of sequence lengths. [45, 49, 10, 50, ...]
'''
import numpy as np
import torch
from torch.utils import data
from pytorch_pretrained_bert import BertTokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-cased', do_lower_case=False)
VOCAB = ('<PAD>', 'O', 'I-LOC', 'B-PER', 'I-PER', 'I-ORG', 'I-MISC', 'B-MISC', 'B-LOC', 'B-ORG')
tag2idx = {tag: idx for idx, tag in enumerate(VOCAB)}
idx2tag = {idx: tag for idx, tag in enumerate(VOCAB)}
class NerDataset(data.Dataset):
def __init__(self, fpath):
"""
fpath: [train|valid|test].txt
"""
        with open(fpath, 'r') as f:
            entries = f.read().strip().split("\n\n")  # each entry is one sentence; sentences are separated by blank lines
sents, tags_li = [], [] # list of lists
for entry in entries:
words = [line.split()[0] for line in entry.splitlines()]
            tags = [line.split()[-1] for line in entry.splitlines()]
sents.append(["[CLS]"] + words + ["[SEP]"])
tags_li.append(["<PAD>"] + tags + ["<PAD>"])
self.sents, self.tags_li = sents, tags_li
def __len__(self):
return len(self.sents)
def __getitem__(self, idx):
words, tags = self.sents[idx], self.tags_li[idx] # words, tags: string list
        # Assign the gold tag only to the first word piece ("head") of each word; later pieces get <PAD>.
x, y = [], [] # list of ids
is_heads = [] # list. 1: the token is the first piece of a word
for w, t in zip(words, tags):
tokens = tokenizer.tokenize(w) if w not in ("[CLS]", "[SEP]") else [w]
xx = tokenizer.convert_tokens_to_ids(tokens)
is_head = [1] + [0]*(len(tokens) - 1)
t = [t] + ["<PAD>"] * (len(tokens) - 1) # <PAD>: no decision
yy = [tag2idx[each] for each in t] # (T,)
x.extend(xx)
is_heads.extend(is_head)
y.extend(yy)
assert len(x)==len(y)==len(is_heads), f"len(x)={len(x)}, len(y)={len(y)}, len(is_heads)={len(is_heads)}"
# seqlen
seqlen = len(y)
# to string
words = " ".join(words)
tags = " ".join(tags)
return words, x, is_heads, tags, y, seqlen
def pad(batch):
    '''Pads every sample in the batch to the length of the longest sample.'''
    f = lambda x: [sample[x] for sample in batch]  # gather field x from every sample in the batch
words = f(0)
is_heads = f(2)
tags = f(3)
seqlens = f(-1)
maxlen = np.array(seqlens).max()
f = lambda x, seqlen: [sample[x] + [0] * (seqlen - len(sample[x])) for sample in batch] # 0: <pad>
x = f(1, maxlen)
y = f(-2, maxlen)
f = torch.LongTensor
return words, f(x), is_heads, tags, f(y), seqlens
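

# A minimal usage sketch: how NerDataset and pad might be wired into a torch
# DataLoader, with pad as the collate_fn so each batch is padded to its
# longest sentence. The path "conll2003/train.txt" and the batch size are
# assumptions, not values from the original module; point them at your own
# data and settings.
if __name__ == "__main__":
    train_dataset = NerDataset("conll2003/train.txt")  # hypothetical path
    train_iter = data.DataLoader(dataset=train_dataset,
                                 batch_size=8,
                                 shuffle=True,
                                 collate_fn=pad)

    # One mini-batch: x and y are LongTensors of shape [N, T];
    # words, is_heads, tags, and seqlens are plain Python lists.
    words, x, is_heads, tags, y, seqlens = next(iter(train_iter))
    print(x.shape, y.shape, seqlens)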