Commit 5f78cdb
piskvorky#1342: Optimize data structures being used for window set tracking and avoid undue network traffic by moving relevancy filtering and token conversion to the master process.
Sweeney, Mack committed May 31, 2017
1 parent e785773 commit 5f78cdb
Showing 1 changed file with 70 additions and 53 deletions.
123 changes: 70 additions & 53 deletions gensim/topic_coherence/text_analysis.py
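
The commit message describes moving relevancy filtering and token-to-id conversion into the master process, so workers receive compact numpy arrays instead of raw token lists. Below is a minimal sketch of that encoding step under toy assumptions; the name encode_text and the example tokens are illustrative, not part of this diff. It mirrors the id2contiguous remapping, the sentinel "none" token, and the dtype selection that appear in the hunks that follow.

import numpy as np

relevant_ids = [17, 42, 256]                               # hypothetical Dictionary token ids
token2id = {'cat': 17, 'dog': 42, 'fish': 256, 'tree': 9}  # hypothetical token2id mapping
id2contiguous = {wid: n for n, wid in enumerate(relevant_ids)}
none_token = len(relevant_ids)                             # sentinel for irrelevant words
dtype = np.uint16 if np.iinfo(np.uint16).max >= len(relevant_ids) else np.uint32

def encode_text(text):
    """Map words to contiguous ids; anything not relevant becomes the sentinel."""
    return np.array(
        [id2contiguous[token2id[w]] if token2id.get(w) in id2contiguous else none_token
         for w in text],
        dtype=dtype)

print(encode_text(['cat', 'tree', 'dog', 'owl']))  # [0 3 1 3]: small arrays cross the queue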
@@ -51,6 +51,7 @@ class BaseAnalyzer(object):

def __init__(self, relevant_ids):
self.relevant_ids = relevant_ids
self._vocab_size = len(self.relevant_ids)
self.id2contiguous = {word_id: n for n, word_id in enumerate(self.relevant_ids)}
self.log_every = 1000
self._num_docs = 0
@@ -92,7 +93,8 @@ def _get_co_occurrences(self, word_id1, word_id2):

class UsesDictionary(BaseAnalyzer):
"""A BaseAnalyzer that uses a Dictionary, hence can translate tokens to counts.
The standard BaseAnalyzer can only deal with token ids since it does not have access to the token2id mapping.
The standard BaseAnalyzer can only deal with token ids since it doesn't have the token2id
mapping.
"""

def __init__(self, relevant_ids, dictionary):
@@ -128,8 +130,7 @@ class InvertedIndexBased(BaseAnalyzer):

def __init__(self, *args):
super(InvertedIndexBased, self).__init__(*args)
vocab_size = len(self.relevant_ids)
self._inverted_index = np.array([set() for _ in range(vocab_size)])
self._inverted_index = np.array([set() for _ in range(self._vocab_size)])

def _get_occurrences(self, word_id):
return len(self._inverted_index[word_id])
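
As a hedged aside on the InvertedIndexBased hunk above (the co-occurrence method itself is collapsed in this view): an inverted index maps each contiguous word id to the set of document numbers it appears in, so an occurrence count is a set size and a co-occurrence count can presumably be read off a set intersection. The toy example below only illustrates that idea and is not the gensim implementation.

import numpy as np

vocab_size = 3
inverted_index = np.array([set() for _ in range(vocab_size)])  # word id -> doc numbers
for doc_num, doc in enumerate([[0, 1], [1, 2], [0, 1, 2]]):    # toy pre-encoded docs
    for word_id in doc:
        inverted_index[word_id].add(doc_num)

print(len(inverted_index[1]))                      # occurrences of word 1: 3
print(len(inverted_index[0] & inverted_index[2]))  # co-occurrences of words 0 and 2: 1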
@@ -169,15 +170,10 @@ def __init__(self, relevant_ids, dictionary):
Args:
----
relevant_ids: the set of words that occurrences should be accumulated for.
dictionary: gensim.corpora.dictionary.Dictionary instance with mappings for the relevant_ids.
dictionary: Dictionary instance with mappings for the relevant_ids.
"""
super(WindowedTextsAnalyzer, self).__init__(relevant_ids, dictionary)

def filter_to_relevant_words(self, text):
"""Lazily filter the text to only those words which are relevant."""
relevant_words = (word for word in text if word in self.relevant_words)
relevant_ids = (self.token2id[word] for word in relevant_words)
return (self.id2contiguous[word_id] for word_id in relevant_ids)
self._none_token = self._vocab_size # see _iter_texts for use of none token

def accumulate(self, texts, window_size):
relevant_texts = self._iter_texts(texts)
Expand All @@ -189,11 +185,13 @@ def accumulate(self, texts, window_size):
return self

def _iter_texts(self, texts):
dtype = np.uint16 if np.iinfo(np.uint16).max >= self._vocab_size else np.uint32
for text in texts:
if self.text_is_relevant(text):
token_ids = (self.token2id[word] if word in self.relevant_words else None
for word in text)
yield [self.id2contiguous[_id] if _id is not None else None for _id in token_ids]
yield np.array([
self.id2contiguous[self.token2id[w]] if w in self.relevant_words
else self._none_token
for w in text], dtype=dtype)

def text_is_relevant(self, text):
"""Return True if the text has any relevant words, else False."""
@@ -208,7 +206,7 @@ class InvertedIndexAccumulator(WindowedTextsAnalyzer, InvertedIndexBased):

def analyze_text(self, window, doc_num=None):
for word_id in window:
if word_id is not None:
if word_id is not self._none_token:
self._inverted_index[word_id].add(self._num_docs)


@@ -217,9 +215,11 @@ class WordOccurrenceAccumulator(WindowedTextsAnalyzer):

def __init__(self, *args):
super(WordOccurrenceAccumulator, self).__init__(*args)
vocab_size = len(self.relevant_words)
self._occurrences = np.zeros(vocab_size, dtype='uint32')
self._co_occurrences = sps.lil_matrix((vocab_size, vocab_size), dtype='uint32')
self._occurrences = np.zeros(self._vocab_size, dtype='uint32')
self._co_occurrences = sps.lil_matrix((self._vocab_size, self._vocab_size), dtype='uint32')

self._uniq_words = np.zeros((self._vocab_size + 1,), dtype=bool) # add 1 for none token
self._mask = self._uniq_words[:-1] # to exclude none token

def __str__(self):
return self.__class__.__name__
@@ -242,25 +242,23 @@ def partial_accumulate(self, texts, window_size):
return self

def analyze_text(self, window, doc_num=None):
self.slide_window(window, doc_num)
if self._mask.any():
self._occurrences[self._mask] += 1

for combo in itertools.combinations(np.nonzero(self._mask)[0], 2):
self._co_occurrences[combo] += 1

def slide_window(self, window, doc_num):
if doc_num != self._current_doc_num:
self._uniq_words = set(window)
self._uniq_words.discard(None)
self._token_at_edge = window[0]
self._uniq_words[:] = False
self._uniq_words[np.unique(window)] = True
self._current_doc_num = doc_num
else:
if self._token_at_edge is not None:
self._uniq_words.discard(self._token_at_edge) # may be irrelevant token
self._token_at_edge = window[0]
self._uniq_words[self._token_at_edge] = False
self._uniq_words[window[-1]] = True

if window[-1] is not None:
self._uniq_words.add(window[-1])

if self._uniq_words:
words_idx = np.array(list(self._uniq_words))
self._occurrences[words_idx] += 1

for combo in itertools.combinations(words_idx, 2):
self._co_occurrences[combo] += 1
self._token_at_edge = window[0]

def _symmetrize(self):
"""Word pairs may have been encountered in (i, j) and (j, i) order.
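
The WordOccurrenceAccumulator hunk above replaces a Python set of in-window words with a fixed-size boolean array (one extra slot absorbs the "none" token), which turns occurrence updates into vectorized fancy indexing. Below is a simplified sketch of that bookkeeping with toy sizes, without the incremental window sliding; it is an illustration, not the diff's code.

import itertools
import numpy as np
import scipy.sparse as sps

vocab_size = 4
occurrences = np.zeros(vocab_size, dtype='uint32')
co_occurrences = sps.lil_matrix((vocab_size, vocab_size), dtype='uint32')
uniq_words = np.zeros(vocab_size + 1, dtype=bool)  # extra slot absorbs the none token
mask = uniq_words[:-1]                             # view that excludes the none token

window = np.array([0, 2, vocab_size, 2], dtype=np.uint16)  # one irrelevant word in the window
uniq_words[:] = False
uniq_words[np.unique(window)] = True

if mask.any():
    occurrences[mask] += 1                         # count each distinct relevant word once
    for i, j in itertools.combinations(np.nonzero(mask)[0], 2):
        co_occurrences[i, j] += 1

print(occurrences)             # [1 0 1 0]
print(co_occurrences[0, 2])    # 1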
@@ -283,15 +281,31 @@ def merge(self, other):
self._num_docs += other._num_docs


class PatchedWordOccurrenceAccumulator(WordOccurrenceAccumulator):
"""Monkey patched for multiprocessing worker usage,
to move some of the logic to the master process.
"""
def _iter_texts(self, texts):
return texts # master process will handle this


class ParallelWordOccurrenceAccumulator(WindowedTextsAnalyzer):
"""Accumulate word occurrences in parallel."""

def __init__(self, processes, *args, **kwargs):
"""
Args:
----
processes : number of processes to use; must be at least two.
args : should include `relevant_ids` and `dictionary` (see `UsesDictionary.__init__`).
kwargs : can include `batch_size`, which is the number of docs to send to a worker at a
time. If not included, it defaults to 32.
"""
super(ParallelWordOccurrenceAccumulator, self).__init__(*args)
if processes < 2:
raise ValueError("Must have at least 2 processes to run in parallel; got %d" % processes)
raise ValueError("Must have at least 2 processes to run in parallel; got %d", processes)
self.processes = processes
self.batch_size = kwargs.get('batch_size', 16)
self.batch_size = kwargs.get('batch_size', 32)

def __str__(self):
return "%s(processes=%s, batch_size=%s)" % (
@@ -303,7 +317,8 @@ def accumulate(self, texts, window_size):
self.queue_all_texts(input_q, texts, window_size)
interrupted = False
except KeyboardInterrupt:
logger.warn("stats accumulation interrupted; <= %d documents processed" % self._num_docs)
logger.warn("stats accumulation interrupted; <= %d documents processed",
self._num_docs)
interrupted = True

accumulators = self.terminate_workers(input_q, output_q, workers, interrupted)
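
The accumulate method above follows a standard producer/consumer layout: start workers, queue batches of pre-encoded texts, then shut everything down and merge the partial results. The sketch below shows that pattern in isolation with hypothetical names and a trivial payload; the None sentinel and the result queue correspond to terminate_workers and reply_to_master further down.

import multiprocessing as mp

def worker(input_q, output_q):
    total = 0
    while True:
        batch = input_q.get(block=True)
        if batch is None:                 # sentinel: no more batches
            break
        total += sum(batch)               # stand-in for partial_accumulate
    output_q.put(total)                   # stand-in for reply_to_master

if __name__ == '__main__':
    input_q, output_q = mp.Queue(), mp.Queue()
    workers = [mp.Process(target=worker, args=(input_q, output_q)) for _ in range(2)]
    for w in workers:
        w.start()
    for batch in ([1, 2], [3, 4], [5]):
        input_q.put(batch, block=True)
    for _ in workers:
        input_q.put(None)                 # one sentinel per worker
    results = [output_q.get() for _ in workers]
    for w in workers:
        w.join()
    print(sum(results))                   # 15, analogous to merging accumulators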
@@ -320,7 +335,7 @@ def start_workers(self, window_size):
output_q = mp.Queue()
workers = []
for _ in range(self.processes):
accumulator = WordOccurrenceAccumulator(self.relevant_ids, self.dictionary)
accumulator = PatchedWordOccurrenceAccumulator(self.relevant_ids, self.dictionary)
worker = AccumulatingWorker(input_q, output_q, accumulator, window_size)
worker.start()
workers.append(worker)
@@ -332,7 +347,7 @@ def yield_batches(self, texts):
`batch_size` texts at a time.
"""
batch = []
for text in texts:
for text in self._iter_texts(texts):
batch.append(text)
if len(batch) == self.batch_size:
yield batch
@@ -345,14 +360,14 @@ def queue_all_texts(self, q, texts, window_size):
"""Sequentially place batches of texts on the given queue until `texts` is consumed.
The texts are filtered so that only those with at least one relevant token are queued.
"""
relevant_texts = (text for text in texts if self.text_is_relevant(text))
for batch_num, batch in enumerate(self.yield_batches(relevant_texts)):
for batch_num, batch in enumerate(self.yield_batches(texts)):
q.put(batch, block=True)
before = self._num_docs / self.log_every
self._num_docs += sum(len(doc) - window_size + 1 for doc in batch)
if before < (self._num_docs / self.log_every):
logger.info("submitted %d batches to accumulate stats from %d documents (%d virtual)" % (
batch_num, (batch_num + 1) * self.batch_size, self._num_docs))
logger.info("%d batches submitted to accumulate stats from %d documents (%d "
"virtual)",
(batch_num + 1), (batch_num + 1) * self.batch_size, self._num_docs)
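
A brief note on the "(%d virtual)" figure logged above: each queued document contributes len(doc) - window_size + 1 sliding windows, and those windows are what the accumulators treat as documents. For example, with hypothetical lengths:

doc_lengths, window_size = [10, 7, 5], 5    # hypothetical batch
virtual_docs = sum(n - window_size + 1 for n in doc_lengths)
print(virtual_docs)                         # 6 + 3 + 1 = 10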

def terminate_workers(self, input_q, output_q, workers, interrupted=False):
"""Wait until all workers have transmitted their WordOccurrenceAccumulator instances,
@@ -392,10 +407,10 @@ def merge_accumulators(self, accumulators):
accumulator = accumulators[0]
for other_accumulator in accumulators[1:]:
accumulator.merge(other_accumulator)
# Workers perform partial accumulation, so none of the co-occurrence matrices are symmetrized.
# This is by design, to avoid unnecessary matrix additions during accumulation.
# Workers do partial accumulation, so none of the co-occurrence matrices are symmetrized.
# This is by design, to avoid unnecessary matrix additions/conversions during accumulation.
accumulator._symmetrize()
logger.info("accumulated word occurrence stats for %d virtual documents" %
logger.info("accumulated word occurrence stats for %d virtual documents",
accumulator.num_docs)
return accumulator
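
Since workers only do partial accumulation, the merged co-occurrence matrix may hold counts for the same unordered pair split across (i, j) and (j, i); the single _symmetrize call folds them together. The method body is collapsed in this diff, so the snippet below is only a hedged illustration of the idea, not the actual implementation.

import scipy.sparse as sps

co = sps.lil_matrix((3, 3), dtype='uint32')
co[0, 1] += 2            # pair (0, 1) counted twice in this order by one worker
co[1, 0] += 1            # ... and once in the reverse order by another

co = co.tocsr()
symmetric = co + co.T    # total count for the unordered pair lands on both sides
print(symmetric[0, 1], symmetric[1, 0])  # 3 3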

@@ -414,30 +429,32 @@ def __init__(self, input_q, output_q, accumulator, window_size):
def run(self):
try:
self._run()
print("finished normally")
except KeyboardInterrupt:
logger.info("%s interrupted after processing %d documents" % (
self.__class__.__name__, self.accumulator.num_docs))
logger.info("%s interrupted after processing %d documents",
self.__class__.__name__, self.accumulator.num_docs)
except Exception as e:
logger.error("worker encountered unexpected exception: %s" % e)
logger.error(traceback.format_exc())
logger.error("worker encountered unexpected exception: %s\n%s",
e, traceback.format_exc())
finally:
self.reply_to_master()

def _run(self):
batch_num = 0
batch_num = -1
n_docs = 0
while True:
batch_num += 1
docs = self.input_q.get(block=True)
if docs is None: # sentinel value
logger.debug("observed sentinel value; terminating")
break

self.accumulator.partial_accumulate(docs, self.window_size)
n_docs += len(docs)
logger.debug("completed batch %d; %d documents processed (%d virtual)" % (
batch_num, n_docs, self.accumulator.num_docs))
batch_num += 1
logger.debug("completed batch %d; %d documents processed (%d virtual)",
batch_num, n_docs, self.accumulator.num_docs)

logger.debug("finished all batches; %d documents processed (%d virtual)",
n_docs, self.accumulator.num_docs)

def reply_to_master(self):
logger.info("serializing accumulator to return to master...")
