diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 6ce517bc57..e73dc4c651 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -101,3 +101,6 @@ /tensorflow_addons/text/crf*.py @squadrick /tensorflow_addons/text/parse_time*.py @helinwang /tensorflow_addons/text/skip_gram*.py @rahulunair + +/tensorflow_addons/layers/crf*.py @howl-anderson +/tensorflow_addons/losses/crf*.py @howl-anderson diff --git a/design_docs/PoC_of_crf_layer.ipynb b/design_docs/PoC_of_crf_layer.ipynb new file mode 100644 index 0000000000..fb3746f39a --- /dev/null +++ b/design_docs/PoC_of_crf_layer.ipynb @@ -0,0 +1,288 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING: Logging before flag parsing goes to stderr.\n", + "W0211 17:23:38.296061 139691842963264 tpu_cluster_resolver.py:35] Falling back to tensorflow client, its recommended to install the cloud tpu client directly with pip install cloud-tpu-client .\n" + ] + } + ], + "source": [ + "import tensorflow as tf\n", + "from tensorflow.python.keras.testing_utils import layer_test\n", + "\n", + "import numpy as np" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Layer define" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### define the layer" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "@tf.keras.utils.register_keras_serializable(package='dummy-package')\n", + "class DummyLayer(tf.keras.layers.Layer):\n", + " # for each tensor, increase value i for each\n", + " def __init__(self, i=1, trainable=True, name=None, dtype=None, dynamic=False, **kwargs):\n", + " self.i = i\n", + " super().__init__(trainable=trainable, name=name, dtype=dtype, dynamic=dynamic, **kwargs)\n", + "\n", + " def build(self, input_shape):\n", + " self.i_constant = tf.constant(self.i, tf.float32)\n", + "\n", + " super().build(input_shape)\n", + "\n", + " def call(self, inputs, **kwargs):\n", + " output = tf.add(inputs, self.i_constant)\n", + " return output\n", + "\n", + " def get_config(self):\n", + " custom_config = {\n", + " \"i\": self.i\n", + " }\n", + "\n", + " base_config = super().get_config()\n", + "\n", + " config = dict(list(base_config.items()) + list(custom_config.items()))\n", + "\n", + " return config" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### test the layer, make sure it works" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "W0211 17:23:39.144376 139691842963264 training_eager.py:274] The list of trainable weights is empty. Make sure that you are not setting model.trainable to False before compiling the model.\n" + ] + }, + { + "data": { + "text/plain": [ + "array([6., 7., 8.], dtype=float32)" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "layer_test(\n", + " DummyLayer,\n", + " kwargs={\"i\": 5},\n", + " input_data=np.array([1, 2, 3], np.float32),\n", + " expected_output=np.array([6, 7, 8], np.float32),\n", + " expected_output_dtype=tf.float32,\n", + " validate_training=True,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Save and Load" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "input_data = np.array([1, 2, 3], np.float32)\n", + "expected_output = np.array([6, 7, 8], np.float32)\n", + "\n", + "x = tf.keras.layers.Input(shape=input_data.shape[1:], dtype=input_data.dtype)\n", + "\n", + "layer = DummyLayer(i=5, name=\"dummy\")\n", + "\n", + "model = tf.keras.models.Model(x, layer(x))\n", + "model.compile('rmsprop', 'mse')" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model: \"model_2\"\n", + "_________________________________________________________________\n", + "Layer (type) Output Shape Param # \n", + "=================================================================\n", + "input_2 (InputLayer) [(None,)] 0 \n", + "_________________________________________________________________\n", + "dummy (DummyLayer) (None,) 0 \n", + "=================================================================\n", + "Total params: 0\n", + "Trainable params: 0\n", + "Non-trainable params: 0\n", + "_________________________________________________________________\n" + ] + } + ], + "source": [ + "model.summary()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### save model" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "model.save(\"model.h5\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### load and test it" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "new_model = tf.keras.models.load_model(\"model.h5\")\n", + "assert new_model.predict(np.array([1], np.float32)) == np.array([6], np.float32)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### load and test it with custom_object" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[6.]\n" + ] + }, + { + "ename": "AssertionError", + "evalue": "oops, custom layer instance is not used by model", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0mexpected\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnew_model\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpredict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfloat32\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mexpected\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 8\u001b[0;31m \u001b[0;32massert\u001b[0m \u001b[0mexpected\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfloat32\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"oops, custom layer instance is not used by model\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;31mAssertionError\u001b[0m: oops, custom layer instance is not used by model" + ] + } + ], + "source": [ + "# NOTE: i is set to 2, not the default value 1 or 5 saved in .h5 file\n", + "layer = DummyLayer(i=2, name=\"dummy\")\n", + "\n", + "new_model = tf.keras.models.load_model(\"model.h5\", custom_objects={\"dummy\": layer})\n", + "\n", + "expected = new_model.predict(np.array([1], np.float32))\n", + "print(expected)\n", + "assert expected == np.array([3], np.float32), \"oops, custom layer instance is not used by model\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### remove model file" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "!rm model.h5" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.8" + }, + "name": "Untitled.ipynb" + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/design_docs/crf.md b/design_docs/crf.md new file mode 100644 index 0000000000..ee3a1af2bd --- /dev/null +++ b/design_docs/crf.md @@ -0,0 +1,153 @@ +# Some technical selection in implementation of CRF layer +## About CRF loss function +currently the crf loss function is desinged as a seperated method/Class. +### Solution 1: standalone loss +In usage it look like below + +```python +from tensorflow_addons.layers import CRF +from tensorflow_addons.losses import crf_loss + +model = Sequential() +model.add(Embedding(3001, 300, mask_zero=True) + +crf = CRF(10) +model.add(crf) + +model.compile('adam', loss=crf_loss) + +model.fit(x, y) +``` + +#### pros #### +the standard way to use loss + +#### cons #### +in the eager mode, there need override a private of base layer to make this solution works. + +code: +```python +def __call__(self, inputs, *args, **kwargs): + outputs = super(CRF, self).__call__(inputs, *args, **kwargs) + + # A hack that add _keras_history to EagerTensor, make it more like normal Tensor + for tensor in tf.nest.flatten(outputs): + if not hasattr(tensor, '_keras_history'): + tensor._keras_history = (self, 0, 0) + + return outputs +``` + +Maybe this patch should submit to tensorflow-core which can also help others to implement a loss function easier for a complicated layer (such like CRF) + +### Solution 2: get from crf layer ### +In usage it look like below + +```python +from tensorflow_addons.layers import CRF + +model = Sequential() +model.add(Embedding(3001, 300, mask_zero=True) + +crf = CRF(10) +model.add(crf) + +crf_loss = crf.get_keras_loss() + +model.compile('adam', loss=crf_loss) + +model.fit(x, y) +``` + +#### pros #### +easy to implement and no more need patch + +#### cons #### + +This solution has a shortage that load model from disk will be difficult. + +##### TensorFlow's default load process don't work ##### + +```python +# Save the model +model.save('path_to_my_model.h5') + +# Recreate the exact same model purely from the file +new_model = keras.models.load_model('path_to_my_model.h5') +``` + +The reason is when Keras core reconstruct the model from disk, it will construct layer and loss from disk independently, so the new loss instance don't have the reference to the new CRF layer instance, therefore the loss function don't work anymore. + +##### A workaround solution (not prefect) ##### +TODO: add a PoC code for this + +This a workaround solution for loading CRF model from disk. + +1. Load the model without compile +```python +new_model = keras.models.load_model('path_to_my_model.h5', compile=Flase) +``` + +2. Get the CRF layer instance +```python +# normally, crf layer is the last layer +crf_layer_instance = new_model.get_layer(index=-1) +``` + +3. Get the CRF loss instance from layer instance +```python +crf_loss_instance = crf_layer_instance.get_keras_loss() +``` + +4. Compile the model +```python +new_model.compile(loss=crf_loss_instance) +``` + +The shortage of this method is user need to add extract code to load the model and all the arguments except the loss passed to model's compile method before will not longer remembered, user need to pass to it again (if their still remember it) + +## About CRF loss + +### Solution 1: inherit from tf.keras.losses.Loss + +#### pros +the recommended way to implement a "normal" loss + +#### cons +according to the code around `tensorflow_core/python/keras/engine/training.py:1651` +`per_sample_losses` returned by `loss_fn.call(y_true, y_pred)` must (or can be converted to) have the same shape with `sample_weight` which default to output `mask` (tensorflow_core/python/keras/engine/training.py:1642) of CRF layer. + +but that is not possible because `per_sample_losses` is a 1d tensor and `mask` of CRF is a 2d tensor. + +One way to fix it is set output `mark` of crf layer to a 1d tensor, which make the mark is considered as not the same meaning as it's name. + +Other way is modified the output of loss class to make `per_sample_losses` to a 2d tensor and properly set the reduce property of the class. It so wired and break the semantic meaning of the interface, should considered to a bad idea. + + +### Solution 2: implement as a function ### + +#### pros #### +This is a old but standard (keras style) way to implement the loss function + +#### cons #### +TensorFlow will convert a loss function into a subclass of `tf.keras.losses.Loss` in `` file by `` (call chain: `tf.keras.Model::compile()` [Line: 314] > `tensorflow/python/keras/engine/training_utils.py::prepare_loss_functions` [Line: 1501] > `tensorflow/python/keras/engine/training_utils.py::get_loss_function` [Line: 1186]). + +```python + # For losses which are given as strings/functions in the compile API, + # we always set the loss reduction type to be `SUM_OVER_BATCH_SIZE` + # (both in distribution strategy context and otherwise). + return losses.LossFunctionWrapper( + loss_fn, + name=loss_fn.__name__, + reduction=losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE) +``` + +So it has same issue that solution 1. + +### Solution 3: implement loss as a callable class + +#### pros +Nothing breaks. `mark` property is still a meaningful tensor which standard as a mark. + +#### cons +this solution need understanding how keras process a loss function, which is not documented and not recommend way in TF 2.x. diff --git a/tensorflow_addons/layers/BUILD b/tensorflow_addons/layers/BUILD index 2ee52d4f6b..98c27ba44e 100644 --- a/tensorflow_addons/layers/BUILD +++ b/tensorflow_addons/layers/BUILD @@ -6,6 +6,7 @@ py_library( name = "layers", srcs = [ "__init__.py", + "crf.py", "gelu.py", "maxout.py", "multihead_attention.py", @@ -23,10 +24,34 @@ py_library( ], deps = [ "//tensorflow_addons/activations", + "//tensorflow_addons/text", "//tensorflow_addons/utils", ], ) +py_library( + name = "crf", + srcs = [ + "crf.py", + ], + deps = [ + "//tensorflow_addons/text:crf", + "//tensorflow_addons/utils", + ], +) + +py_test( + name = "crf_test", + size = "small", + srcs = [ + "crf_test.py", + ], + main = "crf_test.py", + deps = [ + "//tensorflow_addons/layers:crf", + ], +) + py_test( name = "polynomial_test", size = "small", diff --git a/tensorflow_addons/layers/__init__.py b/tensorflow_addons/layers/__init__.py index 9606352b2d..454bfd9b23 100644 --- a/tensorflow_addons/layers/__init__.py +++ b/tensorflow_addons/layers/__init__.py @@ -14,6 +14,7 @@ # ============================================================================== """Additional layers that conform to Keras API.""" +from tensorflow_addons.layers.crf import CRF from tensorflow_addons.layers.gelu import GELU from tensorflow_addons.layers.maxout import Maxout from tensorflow_addons.layers.multihead_attention import MultiHeadAttention diff --git a/tensorflow_addons/layers/crf.py b/tensorflow_addons/layers/crf.py new file mode 100644 index 0000000000..29acd7a8cf --- /dev/null +++ b/tensorflow_addons/layers/crf.py @@ -0,0 +1,445 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Orginal implementation from keras_contrib/layers/crf +# ============================================================================== +"""Implementing Conditional Random Field layer.""" + +from __future__ import absolute_import, division, print_function + +import tensorflow as tf +from typeguard import typechecked + +from tensorflow_addons.text.crf import crf_decode, crf_log_likelihood +from tensorflow_addons.utils import types + + +@tf.keras.utils.register_keras_serializable(package="Addons") +class CRF(tf.keras.layers.Layer): + """Linear chain conditional random field (CRF). + + Examples: + + ```python + from tensorflow_addons.layers import CRF + from tensorflow_addons.losses import crf_loss + + model = Sequential() + model.add(Embedding(3001, 300, mask_zero=True) + + crf = CRF(10) + model.add(crf) + + model.compile('adam', loss=crf_loss) + + model.fit(x, y) + ``` + + Arguments: + units: Positive integer, dimensionality of the output space, + should equal to tag num. + chain_initializer: Initializer for the `chain_kernel` weights matrix, + used for the CRF chain energy. + (see [initializers](../initializers.md)). + chain_regularizer: Regularizer function applied to + the `chain_kernel` weights matrix. + chain_constraint: Constraint function applied to + the `chain_kernel` weights matrix. + use_boundary: Boolean (default True), indicating if trainable + start-end chain energies should be added to model. + boundary_initializer: Initializer for the `left_boundary`, + 'right_boundary' weights vectors, + used for the start/left and end/right boundary energy. + boundary_regularizer: Regularizer function applied to + the 'left_boundary', 'right_boundary' weight vectors. + boundary_constraint: Constraint function applied to + the `left_boundary`, `right_boundary` weights vectors. + use_kernel: Boolean (default True), indicating if apply + a fully connected layer before CRF op. + kernel_initializer: Initializer for the `kernel` weights matrix, + used for the linear transformation of the inputs. + kernel_regularizer: Regularizer function applied to + the `kernel` weights matrix. + kernel_constraint: Constraint function applied to + the `kernel` weights matrix. + use_bias: Boolean (default True), whether the layer uses a bias vector. + bias_initializer: Initializer for the bias vector. + bias_regularizer: Regularizer function applied to the bias vector. + bias_constraint: Constraint function applied to the bias vector. + activation: default value is 'linear', Activation function to use. + + Input shape: + 3D tensor with shape: `(batch_size, sequence_length, feature_size)`. + + Output shape: + 2D tensor (dtype: int32) with shape: `(batch_size, sequence_length)`. + + Masking: + This layer supports masking + (2D tensor, shape: `(batch_size, sequence_length)`) + for input data with a variable number of timesteps. + This layer output same make tensor, + NOTICE this may cause issue when you + use some keras loss and metrics function which usually expect 1D mask. + + Loss function: + Due to the TF 2.0 version support eager execution be default, + there is no way can implement CRF loss as independent loss function. + Thus, user should use loss method of this layer. + See Examples (above) for detailed usage. + + References: + - [Conditional Random Field](https://en.wikipedia.org/wiki/Conditional_random_field) + """ + + @typechecked + def __init__( + self, + units: int, + chain_initializer: types.Initializer = "orthogonal", + chain_regularizer: types.Regularizer = None, + chain_constraint: types.Constraint = None, + use_boundary: bool = True, + boundary_initializer: types.Initializer = "zeros", + boundary_regularizer: types.Regularizer = None, + boundary_constraint: types.Constraint = None, + use_kernel: bool = True, + kernel_initializer: types.Initializer = "glorot_uniform", + kernel_regularizer: types.Regularizer = None, + kernel_constraint: types.Constraint = None, + use_bias: bool = True, + bias_initializer: types.Initializer = "zeros", + bias_regularizer: types.Regularizer = None, + bias_constraint: types.Constraint = None, + activation: types.Activation = "linear", + **kwargs + ): + super(CRF, self).__init__(**kwargs) + + # setup mask supporting flag, used by base class (the Layer) + # because base class's init method will set it to False unconditionally + # So this assigned must be executed after call base class's init method + self.supports_masking = True + + self.units = units # numbers of tags + + self.use_boundary = use_boundary + self.use_bias = use_bias + self.use_kernel = use_kernel + + self.activation = tf.keras.activations.get(activation) + + self.kernel_initializer = tf.keras.initializers.get(kernel_initializer) + self.chain_initializer = tf.keras.initializers.get(chain_initializer) + self.boundary_initializer = tf.keras.initializers.get(boundary_initializer) + self.bias_initializer = tf.keras.initializers.get(bias_initializer) + + self.kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer) + self.chain_regularizer = tf.keras.regularizers.get(chain_regularizer) + self.boundary_regularizer = tf.keras.regularizers.get(boundary_regularizer) + self.bias_regularizer = tf.keras.regularizers.get(bias_regularizer) + + self.kernel_constraint = tf.keras.constraints.get(kernel_constraint) + self.chain_constraint = tf.keras.constraints.get(chain_constraint) + self.boundary_constraint = tf.keras.constraints.get(boundary_constraint) + self.bias_constraint = tf.keras.constraints.get(bias_constraint) + + # values will be assigned in method + self.input_spec = None + + # value remembered for loss/metrics function + self.potentials = None + self.sequence_length = None + self.mask = None + + # global variable + self.chain_kernel = None + self._dense_layer = None + self.left_boundary = None + self.right_boundary = None + + def build(self, input_shape): + input_shape = tuple(tf.TensorShape(input_shape).as_list()) + + # see API docs of InputSpec for more detail + self.input_spec = [tf.keras.layers.InputSpec(shape=input_shape)] + + # weights that work as transfer probability of each tags + self.chain_kernel = self.add_weight( + shape=(self.units, self.units), + name="chain_kernel", + initializer=self.chain_initializer, + regularizer=self.chain_regularizer, + constraint=self.chain_constraint, + ) + + # weight of to tag probability and tag to probability + if self.use_boundary: + self.left_boundary = self.add_weight( + shape=(self.units,), + name="left_boundary", + initializer=self.boundary_initializer, + regularizer=self.boundary_regularizer, + constraint=self.boundary_constraint, + ) + self.right_boundary = self.add_weight( + shape=(self.units,), + name="right_boundary", + initializer=self.boundary_initializer, + regularizer=self.boundary_regularizer, + constraint=self.boundary_constraint, + ) + + if self.use_kernel: + self._dense_layer = tf.keras.layers.Dense( + units=self.units, + activation=self.activation, + use_bias=self.use_bias, + bias_initializer=self.bias_initializer, + kernel_regularizer=self.kernel_regularizer, + bias_regularizer=self.bias_regularizer, + kernel_constraint=self.kernel_constraint, + bias_constraint=self.bias_constraint, + dtype=self.dtype, + ) + else: + self._dense_layer = lambda x: tf.cast(x, dtype=self.dtype) + + super(CRF, self).build(input_shape) + + def call(self, inputs, mask=None, **kwargs): + # mask: Tensor(shape=(batch_size, sequence_length), dtype=bool) or None + + if mask is not None: + if tf.keras.backend.ndim(mask) != 2: + raise ValueError("Input mask to CRF must have dim 2 if not None") + + # left padding of mask is not supported, due the underline CRF function + # detect it and report it to user + first_mask = None + if mask is not None: + left_boundary_mask = self._compute_mask_left_boundary(mask) + first_mask = left_boundary_mask[:, 0] + + # remember this value for later use + self.mask = mask + + if first_mask is not None: + no_left_padding = tf.math.reduce_all(first_mask) + msg = "Currently, CRF layer do not support left padding" + with tf.control_dependencies( + [ + tf.debugging.assert_equal( + no_left_padding, tf.constant(True), message=msg + ) + ] + ): + self.potentials = self._dense_layer(inputs) + else: + self.potentials = self._dense_layer(inputs) + + # appending boundary probability info + if self.use_boundary: + self.potentials = self.add_boundary_energy( + self.potentials, mask, self.left_boundary, self.right_boundary + ) + + self.sequence_length = self._get_sequence_length(inputs, mask) + + decoded_sequence, _ = self.get_viterbi_decoding( + self.potentials, self.sequence_length + ) + + return decoded_sequence + + def _get_sequence_length(self, input_, mask): + """Currently underline CRF fucntion (provided by + tensorflow_addons.text.crf) do not support bi-direction masking (left + padding / right padding), it support right padding by tell it the + sequence length. + + this function is compute the sequence length from input and + mask. + """ + if mask is not None: + sequence_length = self.mask_to_sequence_length(mask) + else: + # make a mask tensor from input, then used to generate sequence_length + input_energy_shape = tf.shape(input_) + raw_input_shape = tf.slice(input_energy_shape, [0], [2]) + alt_mask = tf.ones(raw_input_shape) + + sequence_length = self.mask_to_sequence_length(alt_mask) + + return sequence_length + + def mask_to_sequence_length(self, mask): + """compute sequence length from mask.""" + sequence_length = tf.cast(tf.reduce_sum(tf.cast(mask, tf.int8), 1), tf.int64) + return sequence_length + + @staticmethod + def _compute_mask_right_boundary(mask): + """input mask: 0011100, output left_boundary: 0000100.""" + # shift mask to left by 1: 0011100 => 0111000 + offset = 1 + left_shifted_mask = tf.concat( + [mask[:, offset:], tf.zeros_like(mask[:, :offset])], axis=1 + ) + + # NOTE: below code is different from keras_contrib + # Original code in keras_contrib: + # end_mask = K.cast( + # K.greater(self.shift_left(mask), mask), + # K.floatx() + # ) + # has a bug, confirmed + # by the original keras_contrib maintainer + # Luiz Felix (github: lzfelix), + + # 0011100 > 0111000 => 0000100 + right_boundary = tf.greater(mask, left_shifted_mask) + + return right_boundary + + @staticmethod + def _compute_mask_left_boundary(mask): + """input mask: 0011100, output left_boundary: 0010000.""" + # shift mask to right by 1: 0011100 => 0001110 + offset = 1 + right_shifted_mask = tf.concat( + [tf.zeros_like(mask[:, :offset]), mask[:, :-offset]], axis=1 + ) + + # 0011100 > 0001110 => 0010000 + left_boundary = tf.greater( + tf.cast(mask, tf.int32), tf.cast(right_shifted_mask, tf.int32) + ) + # left_boundary = tf.greater(mask, right_shifted_mask) + + return left_boundary + + def add_boundary_energy(self, potentials, mask, start, end): + def expand_scalar_to_3d(x): + # expand tensor from shape (x, ) to (1, 1, x) + return tf.reshape(x, (1, 1, -1)) + + start = expand_scalar_to_3d(start) + end = expand_scalar_to_3d(end) + if mask is None: + potentials = tf.concat( + [potentials[:, :1, :] + start, potentials[:, 1:, :]], axis=1 + ) + potentials = tf.concat( + [potentials[:, :-1, :], potentials[:, -1:, :] + end], axis=1 + ) + else: + mask = tf.keras.backend.expand_dims(tf.cast(mask, start.dtype), axis=-1) + start_mask = tf.cast(self._compute_mask_left_boundary(mask), start.dtype) + + end_mask = tf.cast(self._compute_mask_right_boundary(mask), end.dtype) + potentials = potentials + start_mask * start + potentials = potentials + end_mask * end + return potentials + + def get_viterbi_decoding(self, potentials, sequence_length): + # decode_tags: A [batch_size, max_seq_len] matrix, with dtype `tf.int32` + decode_tags, best_score = crf_decode( + potentials, self.chain_kernel, sequence_length + ) + + return decode_tags, best_score + + def get_config(self): + # used for loading model from disk + config = { + "units": self.units, + "use_boundary": self.use_boundary, + "use_bias": self.use_bias, + "use_kernel": self.use_kernel, + "kernel_initializer": tf.keras.initializers.serialize( + self.kernel_initializer + ), + "chain_initializer": tf.keras.initializers.serialize( + self.chain_initializer + ), + "boundary_initializer": tf.keras.initializers.serialize( + self.boundary_initializer + ), + "bias_initializer": tf.keras.initializers.serialize(self.bias_initializer), + "activation": tf.keras.activations.serialize(self.activation), + "kernel_regularizer": tf.keras.regularizers.serialize( + self.kernel_regularizer + ), + "chain_regularizer": tf.keras.regularizers.serialize( + self.chain_regularizer + ), + "boundary_regularizer": tf.keras.regularizers.serialize( + self.boundary_regularizer + ), + "bias_regularizer": tf.keras.regularizers.serialize(self.bias_regularizer), + "kernel_constraint": tf.keras.constraints.serialize(self.kernel_constraint), + "chain_constraint": tf.keras.constraints.serialize(self.chain_constraint), + "boundary_constraint": tf.keras.constraints.serialize( + self.boundary_constraint + ), + "bias_constraint": tf.keras.constraints.serialize(self.bias_constraint), + } + base_config = super(CRF, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + def compute_output_shape(self, input_shape): + output_shape = input_shape[:2] + return output_shape + + def compute_mask(self, input_, mask=None): + """keep mask shape [batch_size, max_seq_len]""" + return mask + + def get_negative_log_likelihood(self, y_true): + y_true = tf.cast(y_true, tf.int32) + self.sequence_length = tf.cast(self.sequence_length, tf.int32) + + log_likelihood, _ = crf_log_likelihood( + self.potentials, y_true, self.sequence_length, self.chain_kernel + ) + + return -log_likelihood + + def get_loss(self, y_true, y_pred): + # we don't use y_pred, but caller pass it anyway, ignore it + return self.get_negative_log_likelihood(y_true) + + def get_accuracy(self, y_true, y_pred): + judge = tf.cast(tf.equal(y_pred, y_true), tf.keras.backend.floatx()) + if self.mask is None: + return tf.reduce_mean(judge) + else: + mask = tf.cast(self.mask, tf.keras.backend.floatx()) + return tf.reduce_sum(judge * mask) / tf.reduce_sum(mask) + + def __call__(self, inputs, *args, **kwargs): + outputs = super(CRF, self).__call__(inputs, *args, **kwargs) + + # A hack that add _keras_history to EagerTensor, make it more like normal Tensor + for tensor in tf.nest.flatten(outputs): + if not hasattr(tensor, "_keras_history"): + tensor._keras_history = (self, 0, 0) + + return outputs + + @property + def _compute_dtype(self): + # fixed output dtype from underline CRF functions + return tf.int32 diff --git a/tensorflow_addons/layers/crf_test.py b/tensorflow_addons/layers/crf_test.py new file mode 100644 index 0000000000..9910b391be --- /dev/null +++ b/tensorflow_addons/layers/crf_test.py @@ -0,0 +1,70 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Conditional Random Field layer.""" + +from __future__ import absolute_import, division, print_function + +import numpy as np +import tensorflow as tf + +from tensorflow_addons.layers.crf import CRF +from tensorflow_addons.utils import test_utils + + +@test_utils.run_all_in_graph_and_eager_modes +class TestCRF(tf.test.TestCase): + def test_unmasked_viterbi_decode(self): + x = np.array( + [ + [ + # O B-X I-X B-Y I-Y + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0], + ], + [ + # O B-X I-X B-Y I-Y + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + ], + ] + ) # yapf: disable + + expected_y = np.array( + [[1, 2, 2], [1, 1, 1]] # B-X I-X I-X # B-X B-X B-X + ) # yapf: disable + + transitions = np.ones([5, 5]) + boundary_value = np.ones(5) + + test_utils.layer_test( + CRF, + kwargs={ + "units": 5, + "use_kernel": False, # disable kernel transform + "chain_initializer": tf.keras.initializers.Constant(transitions), + "use_boundary": True, + "boundary_initializer": tf.keras.initializers.Constant(boundary_value), + }, + input_data=x, + expected_output=expected_y, + expected_output_dtype=tf.int32, + validate_training=False, + ) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow_addons/losses/BUILD b/tensorflow_addons/losses/BUILD index 085c08aa6b..306115e894 100644 --- a/tensorflow_addons/losses/BUILD +++ b/tensorflow_addons/losses/BUILD @@ -7,6 +7,7 @@ py_library( srcs = [ "__init__.py", "contrastive.py", + "crf.py", "focal_loss.py", "giou_loss.py", "lifted.py", @@ -18,6 +19,18 @@ py_library( ], deps = [ "//tensorflow_addons/activations", + "//tensorflow_addons/layers", + "//tensorflow_addons/utils", + ], +) + +py_library( + name = "crf", + srcs = [ + "crf.py", + ], + deps = [ + "//tensorflow_addons/layers:crf", "//tensorflow_addons/utils", ], ) @@ -34,6 +47,20 @@ py_test( ], ) +py_test( + name = "crf_test", + size = "small", + srcs = [ + "crf_test.py", + ], + main = "crf_test.py", + srcs_version = "PY2AND3", + deps = [ + "//tensorflow_addons/layers:crf", + "//tensorflow_addons/losses:crf", + ], +) + py_test( name = "focal_loss_test", size = "small", diff --git a/tensorflow_addons/losses/__init__.py b/tensorflow_addons/losses/__init__.py index 4af2be16e0..8cda01fa5a 100644 --- a/tensorflow_addons/losses/__init__.py +++ b/tensorflow_addons/losses/__init__.py @@ -15,6 +15,7 @@ """Additional losses that conform to Keras API.""" from tensorflow_addons.losses.contrastive import contrastive_loss, ContrastiveLoss +from tensorflow_addons.losses.crf import crf_loss, ConditionalRandomFieldLoss from tensorflow_addons.losses.focal_loss import ( sigmoid_focal_crossentropy, SigmoidFocalCrossEntropy, diff --git a/tensorflow_addons/losses/crf.py b/tensorflow_addons/losses/crf.py new file mode 100644 index 0000000000..4eb3fcaaff --- /dev/null +++ b/tensorflow_addons/losses/crf.py @@ -0,0 +1,46 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Implementing Conditional Random Field loss.""" + +from __future__ import absolute_import, division, print_function + +import tensorflow as tf + +from tensorflow_addons.layers.crf import CRF + + +@tf.keras.utils.register_keras_serializable(package="Addons") +class ConditionalRandomFieldLoss(object): + def __init__(self, name: str = "crf_loss"): + self.name = name + + def get_config(self): + return {"name": self.name} + + def __call__(self, y_true, y_pred, sample_weight=None): + crf_layer = y_pred._keras_history[0] + + # check if last layer is CRF + if not isinstance(crf_layer, CRF): + raise ValueError( + "Last layer must be CRF for use {}.".format(self.__class__.__name__) + ) + + loss_vector = crf_layer.get_loss(y_true, y_pred) + + return tf.keras.backend.mean(loss_vector) + + +crf_loss = ConditionalRandomFieldLoss() diff --git a/tensorflow_addons/losses/crf_test.py b/tensorflow_addons/losses/crf_test.py new file mode 100644 index 0000000000..b32dcdec80 --- /dev/null +++ b/tensorflow_addons/losses/crf_test.py @@ -0,0 +1,347 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Conditional Random Field loss.""" + +import itertools +import math +import os + +import numpy as np +import tensorflow as tf +from tensorflow.python.framework import tensor_util +from tensorflow.python.keras.engine import base_layer_utils +from tensorflow.python.util import nest + +from tensorflow_addons.layers.crf import CRF +from tensorflow_addons.losses import crf +from tensorflow_addons.utils import test_utils + +from unittest.mock import patch + +CRF_LOSS_OBJ_LIST = [crf.crf_loss, crf.ConditionalRandomFieldLoss()] + + +@test_utils.run_all_in_graph_and_eager_modes +class ConditionalRandomFieldLossTest(tf.test.TestCase): + def setUp(self): + super().setUp() + + self.logits = np.array( + [ + [[0, 0, 0.5, 0.5, 0.2], [0, 0, 0.3, 0.3, 0.1], [0, 0, 0.9, 10, 1]], + [[0, 0, 0.2, 0.5, 0.2], [0, 0, 3, 0.3, 0.1], [0, 0, 0.9, 1, 1]], + ] + ) + self.tags = np.array([[2, 3, 4], [3, 2, 2]]) + + self.transitions = np.array( + [ + [0.1, 0.2, 0.3, 0.4, 0.5], + [0.8, 0.3, 0.1, 0.7, 0.9], + [-0.3, 2.1, -5.6, 3.4, 4.0], + [0.2, 0.4, 0.6, -0.3, -0.4], + [1.0, 1.0, 1.0, 1.0, 1.0], + ] + ) + + self.boundary_values = np.ones((5,)) + + # Use the CRF Module with fixed transitions to compute the log_likelihood + self.crf = CRF( + units=5, + use_kernel=False, # disable kernel transform + chain_initializer=tf.keras.initializers.Constant(self.transitions), + use_boundary=True, + boundary_initializer=tf.keras.initializers.Constant(self.boundary_values), + name="crf_layer", + ) + + def score(self, logits, tags): + """Computes the likelihood score for the given sequence of tags, given + the provided logits (and the transition weights in the CRF model)""" + # Start with transitions from START and to END + total = self.boundary_values[tags[0]] + self.boundary_values[tags[-1]] + # Add in all the intermediate transitions + for tag, next_tag in zip(tags, tags[1:]): + total += self.transitions[tag, next_tag] + # Add in the logits for the observed tags + for logit, tag in zip(logits, tags): + total += logit[tag] + return total + + def compute_log_likelihood(self): + # Now compute the log-likelihood manually + manual_log_likelihood = 0.0 + + # For each instance, manually compute the numerator + # (which is just the score for the logits and actual tags) + # and the denominator + # (which is the log-sum-exp of the scores + # for the logits across all possible tags) + for logits_i, tags_i in zip(self.logits, self.tags): + numerator = self.score(logits_i, tags_i) + all_scores = [ + self.score(logits_i, tags_j) + for tags_j in itertools.product(range(5), repeat=3) + ] + denominator = math.log(sum(math.exp(score) for score in all_scores)) + # And include them in the manual calculation. + manual_log_likelihood += numerator - denominator + + return manual_log_likelihood + + def _test_loss_function(self, loss_obj): + model = tf.keras.models.Sequential() + model.add(tf.keras.layers.Input(shape=(3, 5))) + model.add(self.crf) + model.compile("adam", loss=loss_obj, metrics=[tf.keras.metrics.Accuracy()]) + + log_likelihood, _ = model.train_on_batch(self.logits, self.tags) + + # The manually computed log likelihood should + # equal the result of crf.forward. + expected_log_likelihood = self.compute_log_likelihood() + unbatched_log_likelihood = -2 * log_likelihood + + self.assertAllClose(expected_log_likelihood, unbatched_log_likelihood) + + def test_class_loss_function(self): + self._test_loss_function(crf.ConditionalRandomFieldLoss()) + + def test_func_loss_function(self): + self._test_loss_function(crf.crf_loss) + + def test_model_fit(self): + for loss_obj in CRF_LOSS_OBJ_LIST: + with self.subTest(loss_obj=loss_obj): + model = tf.keras.models.Sequential() + model.add(tf.keras.layers.Input(shape=(3, 5))) + model.add(self.crf) + model.compile( + "adam", loss=loss_obj, metrics=[tf.keras.metrics.Accuracy()] + ) + + model.fit(self.logits, self.tags, epochs=10, batch_size=1) + + def _test_dump_and_load(self, loss_obj): + tmp_dir = self.get_temp_dir() + MODEL_PERSISTENCE_PATH = os.path.join(tmp_dir, "test_saving_crf_model.h5") + + model = tf.keras.models.Sequential() + model.add(tf.keras.layers.Input(shape=(3, 5))) + model.add(self.crf) + model.compile("adam", loss=loss_obj, metrics=[tf.keras.metrics.Accuracy()]) + + model.fit(self.logits, self.tags, epochs=10, batch_size=1) + + model.save(MODEL_PERSISTENCE_PATH) + + # no news is good news + new_model = tf.keras.models.load_model(MODEL_PERSISTENCE_PATH) + new_model.fit(self.logits, self.tags, epochs=10, batch_size=1) + + try: + os.remove(MODEL_PERSISTENCE_PATH) + except OSError: + pass + + def test_dump_and_load_with_class_loss(self): + # TODO(howl-anderson): wait for the PR merged + self.skipTest("require tensorflow/tensorflow#37018 merged") + + self._test_dump_and_load(crf.ConditionalRandomFieldLoss()) + + def test_mask_left_padding(self): + for loss_obj in CRF_LOSS_OBJ_LIST: + with self.subTest(loss_obj=loss_obj): + train_x = np.array( + [ + [ + # O B-X I-X B-Y I-Y + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0], + ], + [ + # O B-X I-X B-Y I-Y + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + ], + ] + ) # yapf: disable + + train_y = np.array( + [[1, 2, 2], [1, 1, 1]] # B-X I-X I-X # B-X B-X B-X + ) # yapf: disable + + mask = np.array([[0, 1, 1], [1, 1, 1]]) + + layer = CRF(5) + + x = tf.keras.layers.Input(shape=(3, 5)) + y = layer(x, mask=tf.constant(mask)) + + # check shape inference + model = tf.keras.models.Model(x, y) + model.compile("adam", loss_obj) + + with self.assertRaises(tf.errors.InvalidArgumentError) as context: + model.fit(train_x, train_y) + + self.assertTrue( + "CRF layer do not support left padding" in context.exception.message + ) + + def test_mask_right_padding(self): + for loss_obj in CRF_LOSS_OBJ_LIST: + with self.subTest(loss_obj=loss_obj): + train_x = np.array( + [ + [ + # O B-X I-X B-Y I-Y + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0], + ], + [ + # O B-X I-X B-Y I-Y + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + ], + ] + ) # yapf: disable + + train_y = np.array( + [[1, 2, 2], [1, 1, 1]] # B-X I-X I-X # B-X B-X B-X + ) # yapf: disable + + mask = np.array([[1, 1, 1], [1, 1, 0]]) + + layer = CRF(5) + + x = tf.keras.layers.Input(shape=(3, 5)) + y = layer(x, mask=tf.constant(mask)) + + # check shape inference + model = tf.keras.models.Model(x, y) + model.compile("adam", loss_obj) + model.fit(train_x, train_y) + + def test_in_subclass_model(self): + for loss_obj in CRF_LOSS_OBJ_LIST: + with self.subTest(loss_obj=loss_obj): + train_x = np.array( + [ + [ + # O B-X I-X B-Y I-Y + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0], + ], + [ + # O B-X I-X B-Y I-Y + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + ], + ] + ) # yapf: disable + + train_y = np.array( + [[1, 2, 2], [1, 1, 1]] # B-X I-X I-X # B-X B-X B-X + ) # yapf: disable + + def patch_mark_as_return(outputs, acd): + """Marks `outputs` as the return values for automatic control + deps.""" + + def _mark_as_return(tensor): + """Marks `tensor` as the return value for automatic control + deps.""" + if not tensor_util.is_tensor(tensor): + return tensor + + # pylint: disable=protected-access + return_tensor = acd.mark_as_return(tensor) + if getattr(tensor, "_keras_mask", None) is not None: + return_tensor._keras_mask = acd.mark_as_return( + tensor._keras_mask + ) + else: + return_tensor._keras_mask = None + + # TODO(howl-anderson) a little hack here, handle _keras_history + if getattr(tensor, "_keras_history", None) is not None: + return_tensor._keras_history = tensor._keras_history + + # Handle TensorFlow Probability attached metadata. + # TODO(b/132076537): Remove this once TFP uses `CompositeTensor`. + if getattr(tensor, "_tfp_distribution", None) is not None: + return_tensor._tfp_distribution = tensor._tfp_distribution + + return return_tensor + # pylint: enable=protected-access + + return nest.map_structure(_mark_as_return, outputs) + + class CRFModel(tf.keras.Model): + def __init__(self): + super().__init__() + + self.layer = CRF(5) + + def call(self, inputs): + return self.layer(inputs) + + @patch.object( + base_layer_utils, "mark_as_return", patch_mark_as_return + ) + def __call__(self, inputs, *args, **kwargs): + outputs = super().__call__(inputs, *args, **kwargs) + + # A hack that add _keras_history to EagerTensor, make it more like normal Tensor + for tensor in tf.nest.flatten(outputs): + if not hasattr(tensor, "_keras_history"): + tensor._keras_history = (self, 0, 0) + + return outputs + + model = CRFModel() + + model.compile("adam", loss_obj) + model.fit(train_x, train_y) + + def test_serialization(self): + for loss_obj in CRF_LOSS_OBJ_LIST: + with self.subTest(loss_obj=loss_obj): + ref_fn = loss_obj + config = tf.keras.losses.serialize(ref_fn) + fn = tf.keras.losses.deserialize(config) + self.assertEqual(ref_fn.get_config(), fn.get_config()) + + def test_keras_model_compile(self): + for loss_obj in CRF_LOSS_OBJ_LIST: + with self.subTest(loss_obj=loss_obj): + model = tf.keras.models.Sequential( + [tf.keras.layers.Input(shape=(3, 5)), self.crf] + ) + + model.compile(loss=loss_obj, optimizer="adam") + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow_addons/text/BUILD b/tensorflow_addons/text/BUILD index 130ca77909..30978e51ed 100644 --- a/tensorflow_addons/text/BUILD +++ b/tensorflow_addons/text/BUILD @@ -25,6 +25,13 @@ py_library( }), ) +py_library( + name = "crf", + srcs = [ + "crf.py", + ], +) + py_test( name = "crf_test", size = "small",