Skip to content

Commit

Permalink
convert Vocab & related data items to use dataclasses
Browse files Browse the repository at this point in the history
  • Loading branch information
gojomo committed May 10, 2020
1 parent 2f29d74 commit 28bff70
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 22 deletions.
30 changes: 20 additions & 10 deletions gensim/models/doc2vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@

from collections import namedtuple, defaultdict, Iterable
from timeit import default_timer
from dataclasses import dataclass

from numpy import zeros, float32 as REAL, ones, \
memmap as np_memmap, vstack, integer, dtype
Expand Down Expand Up @@ -144,21 +145,30 @@ def __str__(self):
return '%s(%s, %s)' % (self.__class__.__name__, self.words, self.tags)


class Doctag(namedtuple('Doctag', 'index, word_count, doc_count')):
"""A string document tag discovered during the initial vocabulary scan.
The document-vector equivalent of a Vocab object. TODO: merge with Vocab
@dataclass
class DoctagVocab:
"""A dataclass shape-compatible with keyedvectors.SimpleVocab, extended to record
details of string document tags discovered during the initial vocabulary scan.
Will not be used if all presented document tags are ints.
"""
__slots__ = ()

def repeat(self, word_count):
return self._replace(word_count=self.word_count + word_count, doc_count=self.doc_count + 1)
__slots__ = ('doc_count', 'index', 'word_count')
doc_count: int # number of docs where tag appeared
index: int # position in underlying array
word_count: int # number of words in associated docs

@property
def count(self):
return self.doc_count

@count.setter
def count(self, new_val):
self.doc_count = new_val


# compatibility alias, allowing prior namedtuples to unpickle
Doctag = DoctagVocab


class Doc2Vec(BaseWordEmbeddingsModel):
def __init__(self, documents=None, corpus_file=None, dm_mean=None, dm=1, dbow_words=0, dm_concat=0,
Expand Down Expand Up @@ -1029,7 +1039,8 @@ def _scan_vocab(self, documents, docvecs, progress_per, trim_rule):
max_rawint = max(max_rawint, tag)
else:
if tag in doctags_lookup:
doctags_lookup[tag] = doctags_lookup[tag].repeat(document_length)
doctags_lookup[tag].doc_count += 1
doctags_lookup[tag].word_count += document_length
else:
doctags_lookup[tag] = Doctag(index=len(doctags_list), word_count=document_length, doc_count=1)
doctags_list.append(tag)
Expand All @@ -1045,8 +1056,7 @@ def _scan_vocab(self, documents, docvecs, progress_per, trim_rule):
if max_rawint > -1:
# adjust indexes/list to account for range of pure-int keyed doctags
for key in doctags_list:
orig = doctags_lookup[key]
doctags_lookup[key] = orig._replace(index=orig.index + max_rawint + 1)
doctags_lookup[key].index = doctags_lookup[key].index + max_rawint + 1
doctags_list = ConcatList([range(0, max_rawint + 1), doctags_list])

docvecs.vocab = doctags_lookup
Expand Down
27 changes: 24 additions & 3 deletions gensim/models/keyedvectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@
from itertools import chain
import logging
from collections import UserList
from dataclasses import dataclass
from numbers import Integral

try:
Expand Down Expand Up @@ -347,7 +348,7 @@ def add(self, keys, weights, replace=False):
# add new entities to the vocab
for idx in np.nonzero(~in_vocab_mask)[0]:
key = keys[idx]
self.map[key] = Vocab(index=len(self.index2key), count=1)
self.map[key] = SimpleVocab(index=len(self.index2key), count=1)
self.index2key.append(key)

# add vectors for new entities
Expand Down Expand Up @@ -1407,11 +1408,27 @@ def _l2_norm(m, replace=False):
return (m / dist).astype(REAL)


class Vocab(object):
@dataclass
class SimpleVocab:
"""A single vocabulary item, used internally for collecting per-word position in the
backing array (.index), and frequency/sampling info from a corpus survey (.count).
Using a dataclass with fixed __slots__ saves 200+ bytes per entry over the prior
approach (which used a freely-expandable __dict__) – but now requires specialized
uses to define their own expanded data items, which should always include `count`
and `index` properties.
"""
__slots__ = ('count', 'index')
count: int
index: int


class CompatVocab(object):
def __init__(self, **kwargs):
"""A single vocabulary item, used internally for collecting per-word frequency/sampling info,
and for constructing binary trees (incl. both word leaves and inner nodes).
Retained for now to ease the loading of older models.
"""
self.count = 0
self.__dict__.update(kwargs)
Expand All @@ -1424,6 +1441,10 @@ def __str__(self):
return "%s(%s)" % (self.__class__.__name__, ', '.join(vals))


# compatibility alias, allowing older pickle-based `.save()`s to load
Vocab = CompatVocab


# Functions for internal use by _load_word2vec_format function

def _add_word_to_result(result, counts, word, weights, vocab_size):
Expand All @@ -1442,7 +1463,7 @@ def _add_word_to_result(result, counts, word, weights, vocab_size):
logger.warning("vocabulary file is incomplete: '%s' is missing", word)
word_count = None

result.vocab[word] = Vocab(index=word_id, count=word_count)
result.vocab[word] = SimpleVocab(index=word_id, count=word_count)
result.vectors[word_id] = weights
result.index2key.append(word)

Expand Down
57 changes: 51 additions & 6 deletions gensim/models/word2vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,13 +126,15 @@
import heapq
from timeit import default_timer
from copy import deepcopy
from collections import defaultdict
from collections import defaultdict, namedtuple
from dataclasses import dataclass
from typing import List
import threading
import itertools
import warnings

from gensim.utils import keep_vocab_item, call_on_class_only, deprecated
from gensim.models.keyedvectors import Vocab, KeyedVectors, pseudorandom_weak_vector
from gensim.models.keyedvectors import KeyedVectors, pseudorandom_weak_vector
from gensim.models.base_any2vec import BaseWordEmbeddingsModel

try:
Expand Down Expand Up @@ -1039,6 +1041,41 @@ def _scan_vocab_worker(stream, progress_queue, max_vocab_size=None, trim_rule=No
return vocab


@dataclass
class W2VVocab:
"""A dataclass shape-compatible with keyedvectors.SimpleVocab, extended with the
`sample_int` property needed by `Word2Vec` models."""
__slots__ = ('count', 'index', 'sample_int')
count: int
index: int
sample_int: int

def __init__(self, count=0, index=0, sample_int=2**32):
self.count, self.index, self.sample_int = count, index, sample_int

def __lt__(self, other):
return self.count < other.count


@dataclass
class W2VHSVocab:
"""A dataclass shape-compatible with W2VVocab, extended with the `code` and
`point` properties needed by hierarchical-sampling (`hs=1`) `Word2Vec` models."""
__slots__ = ('count', 'index', 'sample_int', 'code', 'point')
count: int
index: int
sample_int: int
code: List[int]
point: List[int]

def __init__(self, count=0, index=0, sample_int=2**32, code=None, point=None):
self.count, self.index, self.sample_int, self.code, self.point = \
count, index, sample_int, code, point

def __lt__(self, other):
return self.count < other.count


class Word2VecVocab(utils.SaveLoad):
def __init__(
self, max_vocab_size=None, min_count=5, sample=1e-3, sorted_vocab=True, null_word=0,
Expand Down Expand Up @@ -1161,7 +1198,7 @@ def prepare_vocab(
retain_words.append(word)
retain_total += v
if not dry_run:
wv.vocab[word] = Vocab(count=v, index=len(wv.index2key))
wv.vocab[word] = W2VVocab(count=v, index=len(wv.index2key))
wv.index2key.append(word)
else:
drop_unique += 1
Expand Down Expand Up @@ -1193,7 +1230,7 @@ def prepare_vocab(
new_words.append(word)
new_total += v
if not dry_run:
wv.vocab[word] = Vocab(count=v, index=len(wv.index2key))
wv.vocab[word] = W2VVocab(count=v, index=len(wv.index2key))
wv.index2key.append(word)
else:
drop_unique += 1
Expand Down Expand Up @@ -1267,7 +1304,7 @@ def prepare_vocab(
return report_values

def add_null_word(self, wv):
word, v = '\0', Vocab(count=1, sample_int=0)
word, v = '\0', W2VVocab(count=1, sample_int=0)
v.index = len(wv.vocab)
wv.index2key.append(word)
wv.vocab[word] = v
Expand Down Expand Up @@ -1305,13 +1342,18 @@ def make_cum_table(self, wv, domain=2**31 - 1):
assert self.cum_table[-1] == domain


class Heapitem(namedtuple('Heapitem', 'count, index, left, right')):
def __lt__(self, other):
return self.count < other.count


def _build_heap(vocab):
heap = list(itervalues(vocab))
heapq.heapify(heap)
for i in range(len(vocab) - 1):
min1, min2 = heapq.heappop(heap), heapq.heappop(heap)
heapq.heappush(
heap, Vocab(count=min1.count + min2.count, index=i + len(vocab), left=min1, right=min2)
heap, Heapitem(count=min1.count + min2.count, index=i + len(vocab), left=min1, right=min2)
)
return heap

Expand All @@ -1338,6 +1380,9 @@ def _assign_binary_codes(vocab):
"""
logger.info("constructing a huffman tree from %i words", len(vocab))

for k in vocab.keys():
# ensure dataclass items sufficient for huffman-encoding
vocab[k] = W2VHSVocab(vocab[k].count, vocab[k].index, vocab[k].sample_int)
heap = _build_heap(vocab)
if not heap:
#
Expand Down
5 changes: 2 additions & 3 deletions gensim/models/wrappers/varembed.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@
import numpy as np

from gensim import utils
from gensim.models.keyedvectors import KeyedVectors
from gensim.models.word2vec import Vocab
from gensim.models.keyedvectors import KeyedVectors, SimpleVocab

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -99,7 +98,7 @@ def load_word_embeddings(self, word_embeddings, word_to_ix):
self.index2word = [None] * self.vocab_size
logger.info("Corpus has %i words", len(self.vocab))
for word_id, word in enumerate(counts):
self.vocab[word] = Vocab(index=word_id, count=counts[word])
self.vocab[word] = SimpleVocab(index=word_id, count=counts[word])
self.vectors[word_id] = word_embeddings[word_to_ix[word]]
self.index2word[word_id] = word
assert((len(self.vocab), self.vector_size) == self.vectors.shape)
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,7 @@ def run(self):
'scipy >= 0.18.1',
'six >= 1.5.0',
'smart_open >= 1.8.1',
"dataclasses; python_version < '3.7'",
]

setup_requires = [NUMPY_STR]
Expand Down

0 comments on commit 28bff70

Please sign in to comment.