[FEATURE] Flexible vocabulary (dmlc#732)
- Allow specification of special tokens as keyword arguments
- Allow users to specify the order of the vocabulary index (a usage sketch of both changes follows below)
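Both changes surface as constructor arguments on `gluonnlp.Vocab` (and its subclasses such as `BERTVocab`). Below is a minimal usage sketch; the exact keyword names (`mask_token` as an extra special token, `token_to_idx` for the index order) are inferred from this commit's description and diff rather than quoted from the documentation.

```python
import gluonnlp as nlp

counter = nlp.data.count_tokens(['hello', 'world', 'hello'])

# Sketch only: special tokens are passed as keyword arguments, and
# `token_to_idx` pins chosen tokens to chosen indices (both assumed here).
vocab = nlp.Vocab(counter,
                  unknown_token='<unk>',
                  padding_token='<pad>',
                  bos_token=None,
                  eos_token=None,
                  mask_token='<mask>',          # extra special token
                  token_to_idx={'<mask>': 0})   # user-chosen index order

print(vocab.idx_to_token[0])   # expected: '<mask>'
print(vocab['never-seen'])     # falls back to the index of '<unk>'
```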
leezu authored Jun 6, 2019
1 parent 4f89ca0 commit 1e50a66
Showing 11 changed files with 580 additions and 278 deletions.
14 changes: 7 additions & 7 deletions scripts/bert/utils.py
@@ -49,11 +49,11 @@ def swap(token, target_idx, token_to_idx, idx_to_token, swap_idx):
idx_to_token[original_idx] = original_token
swap_idx.append((original_idx, target_idx))

- reserved_tokens = [gluonnlp.vocab.BERTVocab.PADDING_TOKEN, gluonnlp.vocab.BERTVocab.CLS_TOKEN,
- gluonnlp.vocab.BERTVocab.SEP_TOKEN, gluonnlp.vocab.BERTVocab.MASK_TOKEN]
+ reserved_tokens = [gluonnlp.vocab.bert.PADDING_TOKEN, gluonnlp.vocab.bert.CLS_TOKEN,
+ gluonnlp.vocab.bert.SEP_TOKEN, gluonnlp.vocab.bert.MASK_TOKEN]

- unknown_token = gluonnlp.vocab.BERTVocab.UNKNOWN_TOKEN
- padding_token = gluonnlp.vocab.BERTVocab.PADDING_TOKEN
+ unknown_token = gluonnlp.vocab.bert.UNKNOWN_TOKEN
+ padding_token = gluonnlp.vocab.bert.PADDING_TOKEN
swap_idx = []
assert unknown_token in token_to_idx
assert padding_token in token_to_idx
@@ -75,9 +75,9 @@ def swap(token, target_idx, token_to_idx, idx_to_token, swap_idx):
bert_vocab_dict['padding_token'] = padding_token
bert_vocab_dict['bos_token'] = None
bert_vocab_dict['eos_token'] = None
- bert_vocab_dict['mask_token'] = gluonnlp.vocab.BERTVocab.MASK_TOKEN
- bert_vocab_dict['sep_token'] = gluonnlp.vocab.BERTVocab.SEP_TOKEN
- bert_vocab_dict['cls_token'] = gluonnlp.vocab.BERTVocab.CLS_TOKEN
+ bert_vocab_dict['mask_token'] = gluonnlp.vocab.bert.MASK_TOKEN
+ bert_vocab_dict['sep_token'] = gluonnlp.vocab.bert.SEP_TOKEN
+ bert_vocab_dict['cls_token'] = gluonnlp.vocab.bert.CLS_TOKEN
json_str = json.dumps(bert_vocab_dict)
converted_vocab = gluonnlp.vocab.BERTVocab.from_json(json_str)
return converted_vocab, swap_idx
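The helper above rebuilds a `BERTVocab` through a JSON round trip and records index swaps so that the special tokens land where the original checkpoint expects them. With the flexible vocabulary this commit introduces, the same constraint can presumably be stated at construction time; the following is only a sketch, and both the special-token literals and the `token_to_idx` keyword on `BERTVocab` are assumptions.

```python
import gluonnlp

counter = gluonnlp.data.count_tokens(['hello', 'world', 'hello'])

# Sketch only: pin the BERT special tokens to fixed indices up front
# instead of swapping vocabulary entries after construction.
bert_vocab = gluonnlp.vocab.BERTVocab(
    counter,
    token_to_idx={'[PAD]': 0, '[CLS]': 1, '[SEP]': 2, '[MASK]': 3})

assert bert_vocab.token_to_idx['[PAD]'] == 0
```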
3 changes: 3 additions & 0 deletions setup.py
@@ -65,6 +65,9 @@ def find_version(*file_paths):
],
'dev': [
'pytest',
+ 'pylint',
+ 'pylint_quotes',
+ 'flake8',
'recommonmark',
'sphinx-gallery',
'sphinx_rtd_theme',
4 changes: 0 additions & 4 deletions src/gluonnlp/_constants.py
@@ -30,10 +30,6 @@

PAD_TOKEN = '<pad>'

- UNK_IDX = 0 # This should not be changed as long as serialized token
- # embeddings redistributed on S3 contain an unknown token.
- # Blame this code change and see commit for more context.

LARGE_POSITIVE_FLOAT = 1e18

LARGE_NEGATIVE_FLOAT = -LARGE_POSITIVE_FLOAT
2 changes: 1 addition & 1 deletion src/gluonnlp/base.py
@@ -25,7 +25,7 @@
__all__ = ['_str_types', 'numba_njit', 'numba_prange', 'numba_jitclass', 'numba_types',
'get_home_dir']

- try:
+ try: # Python 2 compat
_str_types = (str, unicode)
except NameError: # Python 3
_str_types = (str, )
43 changes: 23 additions & 20 deletions src/gluonnlp/embedding/token_embedding.py
@@ -42,6 +42,9 @@
from ..data.utils import DefaultLookupDict
from ..model.train import FasttextEmbeddingModel

+ UNK_IDX = 0 # This should not be changed as long as serialized token
+ # embeddings redistributed on S3 contain an unknown token. Commit
+ # 46eb28ffcf542ab8ed801a727a45404a73adbce3 has more context.

def register(embedding_cls):
"""Registers a new token embedding.
@@ -188,7 +191,7 @@ def __init__(self, unknown_token='<unk>', init_unknown_vec=nd.zeros,
self._unknown_lookup = unknown_lookup
self._idx_to_token = [unknown_token] if unknown_token else []
if unknown_token:
- self._token_to_idx = DefaultLookupDict(C.UNK_IDX)
+ self._token_to_idx = DefaultLookupDict(UNK_IDX)
else:
self._token_to_idx = {}
self._token_to_idx.update((token, idx) for idx, token in enumerate(self._idx_to_token))
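`DefaultLookupDict` is what makes out-of-vocabulary lookups resolve to `UNK_IDX` here. The real class lives in `gluonnlp.data.utils`; the sketch below only illustrates the behaviour it is assumed to have.

```python
class DefaultLookupDict(dict):
    """Dict that returns a fixed default for missing keys without inserting them."""

    def __init__(self, default, d=None):
        super().__init__(d or {})
        self._default = default

    def __getitem__(self, key):
        # Unlike collections.defaultdict, a missing key is not added to the dict.
        return self.get(key, self._default)


UNK_IDX = 0
token_to_idx = DefaultLookupDict(UNK_IDX, {'<unk>': 0, 'hello': 1})
print(token_to_idx['hello'])       # 1
print(token_to_idx['never-seen'])  # 0, the unknown-token index
print(len(token_to_idx))           # still 2: nothing was inserted
```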
@@ -316,9 +319,9 @@ def _load_embedding_txt(self, pretrained_file_path, elem_delim, encoding='utf8')

if self.unknown_token:
if loaded_unknown_vec is None:
- self._idx_to_vec[C.UNK_IDX] = self._init_unknown_vec(shape=vec_len)
+ self._idx_to_vec[UNK_IDX] = self._init_unknown_vec(shape=vec_len)
else:
- self._idx_to_vec[C.UNK_IDX] = nd.array(loaded_unknown_vec)
+ self._idx_to_vec[UNK_IDX] = nd.array(loaded_unknown_vec)

def _load_embedding_serialized(self, pretrained_file_path):
"""Load embedding vectors from a pre-trained token embedding file.
@@ -334,23 +337,23 @@ def _load_embedding_serialized(self, pretrained_file_path):
deserialized_embedding = TokenEmbedding.deserialize(pretrained_file_path)
if deserialized_embedding.unknown_token:
# Some .npz files on S3 may contain an unknown token and its
- # respective embedding. As a workaround, we assume that C.UNK_IDX
+ # respective embedding. As a workaround, we assume that UNK_IDX
# is the same now as it was when the .npz was generated. Under this
# assumption we can safely overwrite the respective token and
# vector from the npz.
if deserialized_embedding.unknown_token:
idx_to_token = deserialized_embedding.idx_to_token
idx_to_vec = deserialized_embedding.idx_to_vec
- idx_to_token[C.UNK_IDX] = self.unknown_token
+ idx_to_token[UNK_IDX] = self.unknown_token
if self._init_unknown_vec:
vec_len = idx_to_vec.shape[1]
- idx_to_vec[C.UNK_IDX] = self._init_unknown_vec(shape=vec_len)
+ idx_to_vec[UNK_IDX] = self._init_unknown_vec(shape=vec_len)
else:
# If the TokenEmbedding shall not have an unknown token, we
# just delete the one in the npz.
- assert C.UNK_IDX == 0
- idx_to_token = deserialized_embedding.idx_to_token[C.UNK_IDX + 1:]
- idx_to_vec = deserialized_embedding.idx_to_vec[C.UNK_IDX + 1:]
+ assert UNK_IDX == 0
+ idx_to_token = deserialized_embedding.idx_to_token[UNK_IDX + 1:]
+ idx_to_vec = deserialized_embedding.idx_to_vec[UNK_IDX + 1:]
else:
idx_to_token = deserialized_embedding.idx_to_token
idx_to_vec = deserialized_embedding.idx_to_vec
@@ -363,10 +366,10 @@ def _load_embedding_serialized(self, pretrained_file_path):
try:
unknown_token_idx = deserialized_embedding.idx_to_token.index(
self.unknown_token)
- idx_to_token[C.UNK_IDX], idx_to_token[
+ idx_to_token[UNK_IDX], idx_to_token[
unknown_token_idx] = idx_to_token[
- unknown_token_idx], idx_to_token[C.UNK_IDX]
- idxs = [C.UNK_IDX, unknown_token_idx]
+ unknown_token_idx], idx_to_token[UNK_IDX]
+ idxs = [UNK_IDX, unknown_token_idx]
idx_to_vec[idxs] = idx_to_vec[idxs[::-1]]
except ValueError:
vec_len = idx_to_vec.shape[1]
@@ -595,7 +598,7 @@ def __setitem__(self, tokens, new_embedding):
if ((self.allow_extend or all(t in self.token_to_idx for t in tokens))
and self._idx_to_vec is None):
# Initialize self._idx_to_vec
- assert C.UNK_IDX == 0
+ assert UNK_IDX == 0
self._idx_to_vec = self._init_unknown_vec(
shape=(1, new_embedding.shape[-1]))

@@ -632,7 +635,7 @@ def __setitem__(self, tokens, new_embedding):
raise KeyError(('Token "{}" is unknown. To update the embedding vector for an'
' unknown token, please explicitly include "{}" as the '
'`unknown_token` in `tokens`. This is to avoid unintended '
- 'updates.').format(token, self._idx_to_token[C.UNK_IDX]))
+ 'updates.').format(token, self._idx_to_token[UNK_IDX]))
else:
raise KeyError(('Token "{}" is unknown. Updating the embedding vector for an '
'unknown token is not allowed because `unknown_token` is not '
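The error messages above spell out the contract of `__setitem__`: vectors can only be assigned for tokens already in the embedding (or for new tokens when `allow_extend` is set), and the unknown token must be named explicitly to be updated. A hedged usage sketch, assuming a bare extendable `TokenEmbedding` can be constructed this way:

```python
from mxnet import nd
import gluonnlp as nlp

# Sketch only: an empty, extendable embedding; the vector size (5 here) is
# fixed by the first assignment, per the _idx_to_vec initialisation above.
emb = nlp.embedding.TokenEmbedding(unknown_token='<unk>', allow_extend=True)
emb['hello'] = nd.ones(5)    # new token, allowed because allow_extend=True
emb['<unk>'] = nd.zeros(5)   # the unknown token must be listed explicitly
# Without allow_extend, assigning to a token outside the vocabulary raises KeyError.
```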
@@ -727,7 +730,7 @@ def serialize(self, file_path, compress=True):
if not unknown_token: # Store empty string instead of None
unknown_token = ''
else:
- assert unknown_token == idx_to_token[C.UNK_IDX]
+ assert unknown_token == idx_to_token[UNK_IDX]

if not compress:
np.savez(file=file_path, unknown_token=unknown_token,
@@ -773,8 +776,8 @@ def deserialize(cls, file_path, **kwargs):

embedding = cls(unknown_token=unknown_token, **kwargs)
if unknown_token:
- assert unknown_token == idx_to_token[C.UNK_IDX]
- embedding._token_to_idx = DefaultLookupDict(C.UNK_IDX)
+ assert unknown_token == idx_to_token[UNK_IDX]
+ embedding._token_to_idx = DefaultLookupDict(UNK_IDX)
else:
embedding._token_to_idx = {}
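`serialize` and `deserialize` round-trip an embedding through a NumPy `.npz` archive while keeping the unknown token at `UNK_IDX`, as the assertions above check. A hedged usage sketch (the file name and the extendable-embedding setup are assumptions):

```python
from mxnet import nd
import gluonnlp as nlp

emb = nlp.embedding.TokenEmbedding(unknown_token='<unk>', allow_extend=True)
emb['hello'] = nd.ones(5)

emb.serialize('my_embedding.npz')  # compress=True is the default per the signature above
restored = nlp.embedding.TokenEmbedding.deserialize('my_embedding.npz')

assert restored.unknown_token == emb.unknown_token
assert 'hello' in restored.token_to_idx
```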

@@ -1049,7 +1052,7 @@ def _load_w2v_binary(self, pretrained_file_path, encoding='utf8'):
"""
self._idx_to_token = [self.unknown_token] if self.unknown_token else []
if self.unknown_token:
- self._token_to_idx = DefaultLookupDict(C.UNK_IDX)
+ self._token_to_idx = DefaultLookupDict(UNK_IDX)
else:
self._token_to_idx = {}
self._token_to_idx.update((token, idx) for idx, token in enumerate(self._idx_to_token))
@@ -1109,9 +1112,9 @@ def _load_w2v_binary(self, pretrained_file_path, encoding='utf8'):

if self.unknown_token:
if loaded_unknown_vec is None:
- self._idx_to_vec[C.UNK_IDX] = self._init_unknown_vec(shape=vec_len)
+ self._idx_to_vec[UNK_IDX] = self._init_unknown_vec(shape=vec_len)
else:
- self._idx_to_vec[C.UNK_IDX] = nd.array(loaded_unknown_vec)
+ self._idx_to_vec[UNK_IDX] = nd.array(loaded_unknown_vec)

@classmethod
def from_w2v_binary(cls, pretrained_file_path, encoding='utf8'):
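`from_w2v_binary`, truncated above, is a classmethod for loading embeddings stored in the original word2vec binary format. A hedged usage sketch; the file path is a placeholder.

```python
import gluonnlp as nlp

# Placeholder path: any file in the word2vec binary format should work here.
emb = nlp.embedding.TokenEmbedding.from_w2v_binary(
    'GoogleNews-vectors-negative300.bin', encoding='utf8')
print(len(emb.idx_to_token), emb.idx_to_vec.shape)
```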