-
Notifications
You must be signed in to change notification settings - Fork 191
/
Copy pathutil.py
147 lines (109 loc) · 3.7 KB
/
util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
# -*- coding: utf-8 -*-
from typing import *
import json
from pathlib import Path
from itertools import repeat
from collections import OrderedDict
import torch
from .class_utils import keys_vocab_cls, iob_labels_vocab_cls
from data_utils import documents
def ensure_dir(dirname):
    '''
    Create the directory ``dirname``, including missing parents, if it does not exist.

    :param dirname: directory path, given as ``str`` or ``pathlib.Path``.
    :raises FileExistsError: if the path exists but is not a directory.
    '''
    # exist_ok=True avoids the check-then-create race: a separate is_dir() test
    # followed by mkdir(exist_ok=False) fails with FileExistsError when another
    # process creates the directory between the two calls.
    Path(dirname).mkdir(parents=True, exist_ok=True)
def read_json(fname):
    '''
    Load a JSON file and return its content, with every JSON object decoded
    as an ``OrderedDict``.

    :param fname: path of the JSON file, given as ``str`` or ``pathlib.Path``.
    :return: the decoded JSON document.
    '''
    # JSON text is UTF-8; decode it explicitly instead of relying on the
    # platform's locale encoding, which mis-decodes non-ASCII content on
    # systems whose default is not UTF-8.
    with Path(fname).open('rt', encoding='utf-8') as handle:
        return json.load(handle, object_hook=OrderedDict)
def write_json(content, fname):
    '''
    Serialize ``content`` as JSON into the file ``fname``.

    The output is indented by four spaces and keys keep their insertion order.

    :param content: JSON-serializable object to write.
    :param fname: destination path, given as ``str`` or ``pathlib.Path``.
    '''
    target = Path(fname)
    with target.open(mode='wt') as out_file:
        json.dump(content, out_file, sort_keys=False, indent=4)
def inf_loop(data_loader):
    '''
    Endless wrapper around a data loader.

    Yields every item of ``data_loader``; once it is exhausted, iteration
    restarts from the beginning, so the generator never terminates on its own.

    :param data_loader: a re-iterable object.
    :return: generator over the items of ``data_loader``, repeated forever.
    '''
    while True:
        yield from data_loader
def iob2entity(tag):
    '''
    Convert an IOB label to its entity name.

    The outside label ``'O'`` is returned unchanged. For any label longer than
    one character, the first two characters (e.g. ``'B-'`` or ``'I-'``) are
    dropped and the remainder is returned.

    :param tag: IOB label string.
    :return: ``'O'`` or the label with its two-character prefix removed.
    :raises TypeError: if ``tag`` is empty or is a single character other than ``'O'``.
    '''
    if len(tag) <= 1:
        if tag == 'O':
            return tag
        # Covers both a single non-'O' character and the empty tag; the latter
        # previously fell through every branch and silently returned None.
        raise TypeError('Invalid tag!')
    return tag[2:]
def iob_index_to_str(tags: List[List[int]]):
    '''
    Decode IOB label indices into label strings.

    Each index is looked up in ``iob_labels_vocab_cls.itos``; the special
    tokens ``'<unk>'`` and ``'<pad>'`` are replaced by ``'O'``.

    :param tags: one list of label indices per document.
    :return: one list of label strings per document, in the same order.
    '''
    special_tokens = ('<unk>', '<pad>')
    decoded_docs = []
    for doc in tags:
        labels = (iob_labels_vocab_cls.itos[index] for index in doc)
        decoded_docs.append(['O' if label in special_tokens else label for label in labels])
    return decoded_docs
def text_index_to_str(texts: torch.Tensor, mask: torch.Tensor):
    '''
    Decode text indices into token strings, one flat sequence per batch item.

    The inputs are first packed by ``texts_to_union_texts``; each resulting
    index is then looked up in ``keys_vocab_cls.itos``. The special tokens
    ``'<unk>'`` and ``'<pad>'`` are replaced by ``'O'``.

    :param texts: text indices, passed through to ``texts_to_union_texts``.
    :param mask: mask for ``texts``, passed through to ``texts_to_union_texts``.
    :return: one list of token strings per batch item.
    '''
    special_tokens = ('<unk>', '<pad>')
    # packed: (B, num_boxes * T), truncated to the longest valid sequence
    packed = texts_to_union_texts(texts, mask)

    decoded_batch = []
    for row in packed:
        tokens = (keys_vocab_cls.itos[index] for index in row)
        decoded_batch.append(['O' if token in special_tokens else token for token in tokens])
    return decoded_batch
def texts_to_union_texts(texts, mask):
    '''
    Pack the valid text indices of every batch item into one left-aligned sequence.

    The box and transcript dimensions are flattened, the entries selected by
    ``mask`` are moved to the front of each row, and the remainder is filled
    with the ``'<pad>'`` index of ``keys_vocab_cls``. The result is truncated
    to the longest valid sequence in the batch.

    :param texts: (B, N, T)
    :param mask: (B, N, T)
    :return: (B, L), where L is the largest number of valid entries in any row.
    '''
    batch_size, num_boxes, seq_len = texts.shape
    flat_len = num_boxes * seq_len
    flat_texts = texts.reshape(batch_size, flat_len)
    flat_mask = mask.reshape(batch_size, flat_len)

    # start from an all-padding tensor, (B, N*T)
    packed = torch.full_like(flat_texts, keys_vocab_cls['<pad>'], device=flat_texts.device)

    longest = 0
    for row_idx in range(batch_size):
        kept = torch.masked_select(flat_texts[row_idx], flat_mask[row_idx].bool())
        kept_len = kept.size(0)
        packed[row_idx, :kept_len] = kept
        longest = max(longest, kept_len)

    # drop the columns that are padding in every row
    return packed[:, :longest]
def iob_tags_to_union_iob_tags(iob_tags, mask):
    '''
    Pack the valid IOB tag indices of every batch item into one left-aligned sequence.

    The box and transcript dimensions are flattened, the entries selected by
    ``mask`` are moved to the front of each row, and the remainder is filled
    with the ``'<pad>'`` index of ``iob_labels_vocab_cls``. The result is
    truncated to the longest valid sequence in the batch.

    :param iob_tags: (B, N, T)
    :param mask: (B, N, T)
    :return: (B, L), where L is the largest number of valid entries in any row.
    '''
    batch_size, num_boxes, seq_len = iob_tags.shape
    flat_len = num_boxes * seq_len
    flat_tags = iob_tags.reshape(batch_size, flat_len)
    flat_mask = mask.reshape(batch_size, flat_len)

    # start from an all-padding tensor, (B, N*T)
    packed = torch.full_like(flat_tags, iob_labels_vocab_cls['<pad>'], device=flat_tags.device)

    longest = 0
    for row_idx in range(batch_size):
        kept = torch.masked_select(flat_tags[row_idx], flat_mask[row_idx].bool())
        kept_len = kept.size(0)
        packed[row_idx, :kept_len] = kept
        longest = max(longest, kept_len)

    # drop the columns that are padding in every row
    return packed[:, :longest]