Skip to content

Commit

Permalink
Make most_similar accept any integral topn
Browse files Browse the repository at this point in the history
  • Loading branch information
Witiko committed May 17, 2019
1 parent d910fa9 commit b34b06e
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 4 deletions.
7 changes: 4 additions & 3 deletions gensim/models/keyedvectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@

from itertools import chain
import logging
from numbers import Integral

try:
from queue import Queue, Empty
Expand Down Expand Up @@ -519,7 +520,7 @@ def most_similar(self, positive=None, negative=None, topn=10, restrict_vocab=Non
one-dimensional numpy array with the size of the vocabulary.
"""
if isinstance(topn, int) and topn < 1:
if isinstance(topn, Integral) and topn < 1:
return []

if positive is None:
Expand Down Expand Up @@ -807,7 +808,7 @@ def most_similar_cosmul(self, positive=None, negative=None, topn=10):
one-dimensional numpy array with the size of the vocabulary.
"""
if isinstance(topn, int) and topn < 1:
if isinstance(topn, Integral) and topn < 1:
return []

if positive is None:
Expand Down Expand Up @@ -1678,7 +1679,7 @@ def most_similar(self, positive=None, negative=None, topn=10, clip_start=0, clip
Sequence of (doctag/index, similarity).
"""
if isinstance(topn, int) and topn < 1:
if isinstance(topn, Integral) and topn < 1:
return []

if positive is None:
Expand Down
3 changes: 2 additions & 1 deletion gensim/models/poincare.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@

import csv
import logging
from numbers import Integral
import sys
import time

Expand Down Expand Up @@ -1219,7 +1220,7 @@ def most_similar(self, node_or_vector, topn=10, restrict_vocab=None):
[(u'kangaroo.n.01', 0.0), (u'marsupial.n.01', 0.26524229460827725)]
"""
if isinstance(topn, int) and topn < 1:
if isinstance(topn, Integral) and topn < 1:
return []

if not restrict_vocab:
Expand Down
3 changes: 3 additions & 0 deletions gensim/test/test_keyedvectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,9 @@ def test_most_similar_topn(self):
predicted = self.vectors.most_similar('war', topn=0)
self.assertEqual(len(predicted), 0)

predicted = self.vectors.most_similar('war', topn=np.uint8(0))
self.assertEqual(len(predicted), 0)

def test_relative_cosine_similarity(self):
"""Test relative_cosine_similarity returns expected results with an input of a word pair and topn"""
wordnet_syn = [
Expand Down

0 comments on commit b34b06e

Please sign in to comment.