Skip to content

Commit

Permalink
added TopKSampleEmbeddingHelper
Browse files Browse the repository at this point in the history
  • Loading branch information
ZhitingHu committed Feb 22, 2019
1 parent 4763fa9 commit baa09ff
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 1 deletion.
5 changes: 5 additions & 0 deletions docs/code/modules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,11 @@ Decoders
.. autoclass:: texar.modules.TransformerDecoderOutput
:members:

:hidden:`TopKSampleEmbeddingHelper`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: texar.modules.TopKSampleEmbeddingHelper
:members:

:hidden:`SoftmaxEmbeddingHelper`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: texar.modules.SoftmaxEmbeddingHelper
Expand Down
86 changes: 85 additions & 1 deletion texar/modules/decoders/rnn_decoder_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import tensorflow as tf
from tensorflow.contrib.seq2seq import TrainingHelper as TFTrainingHelper
from tensorflow.contrib.seq2seq import Helper as TFHelper
from tensorflow.contrib.seq2seq import GreedyEmbeddingHelper
from tensorflow.python.ops.distributions import categorical
from tensorflow.contrib.distributions import RelaxedOneHotCategorical \
as GumbelSoftmax

Expand All @@ -36,8 +38,9 @@
"default_helper_infer_hparams",
"get_helper",
"_get_training_helper",
"TopKSampleEmbeddingHelper",
"GumbelSoftmaxEmbeddingHelper",
"SoftmaxEmbeddingHelper",
"SoftmaxEmbeddingHelper"
]

def default_helper_train_hparams():
Expand Down Expand Up @@ -185,6 +188,87 @@ def _get_training_helper( #pylint: disable=invalid-name
return helper


def _top_k_logits(logits, k):
    """Masks every logit that is smaller than the `k`-th largest in its row.

    Adapted from
    https://github.com/openai/gpt-2/blob/master/src/sample.py#L63-L77

    Args:
        logits: Logits tensor. It is indexed as a 2D tensor here, with the
            candidates laid out along the last axis.
        k: Number of top candidates to keep. `0` means no truncation.

    Returns:
        `logits` itself when `k` is `0`; otherwise a tensor of the same
        shape in which entries below the row-wise `k`-th largest value are
        replaced by `-1e10`.
    """
    if k == 0:
        # Statically known "no truncation": return without adding graph ops.
        return logits

    def _keep_all():
        return logits

    def _mask_below_kth():
        top_values, _ = tf.nn.top_k(logits, k=k)
        # Last column of the top-k values is the per-row cut-off; keep a
        # trailing axis so it broadcasts against `logits`.
        cutoff = top_values[:, -1, tf.newaxis]
        is_below_cutoff = logits < cutoff
        large_negative = tf.ones_like(logits, dtype=logits.dtype) * -1e10
        return tf.where(is_below_cutoff, large_negative, logits)

    # `k` can also be a tensor, in which case the zero check has to happen
    # inside the graph.
    return tf.cond(tf.equal(k, 0), _keep_all, _mask_below_kth)

class TopKSampleEmbeddingHelper(GreedyEmbeddingHelper):
    """An inference-time helper that performs top-k sampling.

    At each step a token is sampled from the `top_k` most likely candidates
    of the vocabulary distribution, and the sampled id is passed through
    the embedding layer to produce the next decoder input.
    """

    def __init__(self, embedding, start_tokens, end_token, top_k=10,
                 softmax_temperature=None, seed=None):
        """Initializer.

        Args:
            embedding: A callable that takes a vector tensor of `ids`
                (argmax ids), or the `params` argument for
                :tf_main:`tf.nn.embedding_lookup <nn/embedding_lookup>`, or
                an instance of subclass of
                :class:`texar.modules.EmbedderBase`. The returned tensor
                will be passed to the decoder input.
            start_tokens: `int32` vector shaped `[batch_size]`, the start
                tokens.
            end_token: `int32` scalar, the token that marks end of decoding.
            top_k: `int32` scalar tensor. Number of top candidates to
                sample from. Must be `>=0`. If set to 0, samples from all
                candidates (i.e., regular random sample decoding).
            softmax_temperature (optional): `float32` scalar, value to
                divide the logits by before sampling. Larger values
                (above 1.0) result in more random samples, while smaller
                values push the sampling distribution towards the argmax.
                Must be strictly greater than 0. If `None` (default), the
                logits are used as they are.
            seed (optional): The sampling seed.

        Raises:
            ValueError: if `start_tokens` is not a 1D tensor or `end_token`
                is not a scalar.
        """
        super(TopKSampleEmbeddingHelper, self).__init__(
            embedding, start_tokens, end_token)
        self._seed = seed
        self._softmax_temperature = softmax_temperature
        self._top_k = top_k

    def sample(self, time, outputs, state, name=None):
        """Samples token ids from the top-k candidates of `outputs`.

        Args:
            time: Unused.
            outputs: A single Tensor of logits produced by the decoder.
            state: Unused.
            name: Unused.

        Returns:
            The sampled ids.

        Raises:
            TypeError: if `outputs` is not a `tf.Tensor`.
        """
        del time, state  # unused by sample_fn
        if not isinstance(outputs, tf.Tensor):
            raise TypeError("Expected outputs to be a single Tensor, got: %s" %
                            type(outputs))

        temperature = self._softmax_temperature
        # Temperature scaling is optional; `None` leaves the logits untouched.
        scaled_logits = outputs if temperature is None \
            else outputs / temperature

        # Restrict sampling to the `top_k` most likely candidates.
        truncated_logits = _top_k_logits(scaled_logits, k=self._top_k)

        sampler = categorical.Categorical(logits=truncated_logits)
        return sampler.sample(seed=self._seed)

class SoftmaxEmbeddingHelper(TFHelper):
"""A helper that feeds softmax probabilities over vocabulary
to the next step.
Expand Down

0 comments on commit baa09ff

Please sign in to comment.