-
Notifications
You must be signed in to change notification settings - Fork 613
/
Copy pathcrf.py
275 lines (230 loc) · 10.1 KB
/
crf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Orginal implementation from keras_contrib/layers/crf
# ==============================================================================
"""Implementing Conditional Random Field layer."""
import tensorflow as tf
from typeguard import typechecked
from tensorflow_addons.text.crf import crf_decode
from tensorflow_addons.utils import types
@tf.keras.utils.register_keras_serializable(package="Addons")
class CRF(tf.keras.layers.Layer):
"""Linear chain conditional random field (CRF).
Inherits from: `tf.keras.layers.Layer`.
References:
- [Conditional Random Field](https://en.wikipedia.org/wiki/Conditional_random_field)
Example:
>>> layer = tfa.layers.CRF(4)
>>> inputs = np.random.rand(2, 4, 8).astype(np.float32)
>>> decoded_sequence, potentials, sequence_length, chain_kernel = layer(inputs)
>>> decoded_sequence.shape
TensorShape([2, 4])
>>> potentials.shape
TensorShape([2, 4, 4])
>>> sequence_length
<tf.Tensor: shape=(2,), dtype=int64, numpy=array([4, 4])>
>>> chain_kernel.shape
TensorShape([4, 4])
Args:
units: Positive integer, dimensionality of the reservoir.
chain_initializer: Orthogonal matrix. Default to `orthogonal`.
use_boundary: `Boolean`, whether the layer uses a boundary vector. Default to `True`.
boundary_initializer: Tensors initialized to 0. Default to `zeros`.
use_kernel: `Boolean`, whether the layer uses a kernel weights. Default to `True`.
Call Args:
inputs: Positive integer, dimensionality of the output space.
mask: A boolean `Tensor` of shape `[batch_size, sequence_length]`
or `None`. Default to `None`.
Raises:
ValueError: If input mask doesn't have dim 2 or None.
NotImplementedError: If left padding is provided.
"""
@typechecked
def __init__(
self,
units: int,
chain_initializer: types.Initializer = "orthogonal",
use_boundary: bool = True,
boundary_initializer: types.Initializer = "zeros",
use_kernel: bool = True,
**kwargs,
):
super().__init__(**kwargs)
# setup mask supporting flag, used by base class (the Layer)
# because base class's init method will set it to False unconditionally
# So this assigned must be executed after call base class's init method
self.supports_masking = True
self.units = units # numbers of tags
self.use_boundary = use_boundary
self.use_kernel = use_kernel
self.chain_initializer = tf.keras.initializers.get(chain_initializer)
self.boundary_initializer = tf.keras.initializers.get(boundary_initializer)
# weights that work as transfer probability of each tags
self.chain_kernel = self.add_weight(
shape=(self.units, self.units),
name="chain_kernel",
initializer=self.chain_initializer,
)
# weight of <START> to tag probability and tag to <END> probability
if self.use_boundary:
self.left_boundary = self.add_weight(
shape=(self.units,),
name="left_boundary",
initializer=self.boundary_initializer,
)
self.right_boundary = self.add_weight(
shape=(self.units,),
name="right_boundary",
initializer=self.boundary_initializer,
)
if self.use_kernel:
self._dense_layer = tf.keras.layers.Dense(
units=self.units, dtype=self.dtype
)
else:
self._dense_layer = lambda x: tf.cast(x, dtype=self.dtype)
def call(self, inputs, mask=None):
# mask: Tensor(shape=(batch_size, sequence_length), dtype=bool) or None
if mask is not None:
if tf.keras.backend.ndim(mask) != 2:
raise ValueError("Input mask to CRF must have dim 2 if not None")
if mask is not None:
# left padding of mask is not supported, due the underline CRF function
# detect it and report it to user
left_boundary_mask = self._compute_mask_left_boundary(mask)
first_mask = left_boundary_mask[:, 0]
if first_mask is not None and tf.executing_eagerly():
no_left_padding = tf.math.reduce_all(first_mask)
left_padding = not no_left_padding
if left_padding:
raise NotImplementedError(
"Currently, CRF layer do not support left padding"
)
potentials = self._dense_layer(inputs)
# appending boundary probability info
if self.use_boundary:
potentials = self.add_boundary_energy(
potentials, mask, self.left_boundary, self.right_boundary
)
sequence_length = self._get_sequence_length(inputs, mask)
decoded_sequence, _ = self.get_viterbi_decoding(potentials, sequence_length)
return [decoded_sequence, potentials, sequence_length, self.chain_kernel]
def _get_sequence_length(self, input_, mask):
"""Currently underline CRF fucntion (provided by
tensorflow_addons.text.crf) do not support bi-direction masking (left
padding / right padding), it support right padding by tell it the
sequence length.
this function is compute the sequence length from input and
mask.
"""
if mask is not None:
sequence_length = self.mask_to_sequence_length(mask)
else:
# make a mask tensor from input, then used to generate sequence_length
input_energy_shape = tf.shape(input_)
raw_input_shape = tf.slice(input_energy_shape, [0], [2])
alt_mask = tf.ones(raw_input_shape)
sequence_length = self.mask_to_sequence_length(alt_mask)
return sequence_length
def mask_to_sequence_length(self, mask):
"""compute sequence length from mask."""
sequence_length = tf.reduce_sum(tf.cast(mask, tf.int64), 1)
return sequence_length
@staticmethod
def _compute_mask_right_boundary(mask):
"""input mask: 0011100, output right_boundary: 0000100."""
# shift mask to left by 1: 0011100 => 0111000
offset = 1
left_shifted_mask = tf.concat(
[mask[:, offset:], tf.zeros_like(mask[:, :offset])], axis=1
)
# NOTE: below code is different from keras_contrib
# Original code in keras_contrib:
# end_mask = K.cast(
# K.greater(self.shift_left(mask), mask),
# K.floatx()
# )
# has a bug, confirmed
# by the original keras_contrib maintainer
# Luiz Felix (github: lzfelix),
# 0011100 > 0111000 => 0000100
right_boundary = tf.math.greater(
tf.cast(mask, tf.int32), tf.cast(left_shifted_mask, tf.int32)
)
return right_boundary
@staticmethod
def _compute_mask_left_boundary(mask):
"""input mask: 0011100, output left_boundary: 0010000."""
# shift mask to right by 1: 0011100 => 0001110
offset = 1
right_shifted_mask = tf.concat(
[tf.zeros_like(mask[:, :offset]), mask[:, :-offset]], axis=1
)
# 0011100 > 0001110 => 0010000
left_boundary = tf.math.greater(
tf.cast(mask, tf.int32), tf.cast(right_shifted_mask, tf.int32)
)
return left_boundary
def add_boundary_energy(self, potentials, mask, start, end):
def expand_scalar_to_3d(x):
# expand tensor from shape (x, ) to (1, 1, x)
return tf.reshape(x, (1, 1, -1))
start = tf.cast(expand_scalar_to_3d(start), potentials.dtype)
end = tf.cast(expand_scalar_to_3d(end), potentials.dtype)
if mask is None:
potentials = tf.concat(
[potentials[:, :1, :] + start, potentials[:, 1:, :]], axis=1
)
potentials = tf.concat(
[potentials[:, :-1, :], potentials[:, -1:, :] + end], axis=1
)
else:
mask = tf.keras.backend.expand_dims(tf.cast(mask, start.dtype), axis=-1)
start_mask = tf.cast(self._compute_mask_left_boundary(mask), start.dtype)
end_mask = tf.cast(self._compute_mask_right_boundary(mask), end.dtype)
potentials = potentials + start_mask * start
potentials = potentials + end_mask * end
return potentials
def get_viterbi_decoding(self, potentials, sequence_length):
# decode_tags: A [batch_size, max_seq_len] matrix, with dtype `tf.int32`
decode_tags, best_score = crf_decode(
potentials, self.chain_kernel, sequence_length
)
return decode_tags, best_score
def get_config(self):
# used for loading model from disk
config = {
"units": self.units,
"chain_initializer": tf.keras.initializers.serialize(
self.chain_initializer
),
"use_boundary": self.use_boundary,
"boundary_initializer": tf.keras.initializers.serialize(
self.boundary_initializer
),
"use_kernel": self.use_kernel,
}
base_config = super().get_config()
return {**base_config, **config}
def compute_output_shape(self, input_shape):
output_shape = input_shape[:2]
return output_shape
def compute_mask(self, input_, mask=None):
"""keep mask shape [batch_size, max_seq_len]"""
return mask
@property
def _compute_dtype(self):
# fixed output dtype from underline CRF functions
return tf.int32