diff --git a/gensim/summarization/bm25.py b/gensim/summarization/bm25.py index ec484949cf..3a2bf5bbf6 100644 --- a/gensim/summarization/bm25.py +++ b/gensim/summarization/bm25.py @@ -4,7 +4,7 @@ # Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html """This module contains function of computing rank scores for documents in -corpus and helper class `BM25` used in calculations. Original alhorithm +corpus and helper class `BM25` used in calculations. Original algorithm descibed in [1]_, also you may check Wikipedia page [2]_. @@ -61,7 +61,8 @@ class BM25(object): Dictionary with terms frequencies for whole `corpus`. Words used as keys and frequencies as values. idf : dict Dictionary with inversed terms frequencies for whole `corpus`. Words used as keys and frequencies as values. - + doc_len : list of int + List of document lengths. """ def __init__(self, corpus): @@ -78,12 +79,14 @@ def __init__(self, corpus): self.f = [] self.df = {} self.idf = {} + self.doc_len = [] self.initialize() def initialize(self): """Calculates frequencies of terms in documents and in corpus. Also computes inverse document frequencies.""" for document in self.corpus: frequencies = {} + self.doc_len.append(len(document)) for word in document: if word not in frequencies: frequencies[word] = 0 @@ -122,7 +125,7 @@ def get_score(self, document, index, average_idf): continue idf = self.idf[word] if self.idf[word] >= 0 else EPSILON * average_idf score += (idf * self.f[index][word] * (PARAM_K1 + 1) - / (self.f[index][word] + PARAM_K1 * (1 - PARAM_B + PARAM_B * len(document) / self.avgdl))) + / (self.f[index][word] + PARAM_K1 * (1 - PARAM_B + PARAM_B * self.doc_len[index] / self.avgdl))) return score def get_scores(self, document, average_idf): diff --git a/gensim/test/test_BM25.py b/gensim/test/test_BM25.py new file mode 100644 index 0000000000..a96302e8c9 --- /dev/null +++ b/gensim/test/test_BM25.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (C) 2010 Radim Rehurek +# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html + +""" +Automated tests for checking transformation algorithms (the models package). +""" + +import logging +import unittest + +from gensim.summarization.bm25 import get_bm25_weights +from gensim.test.utils import common_texts + + +class TestBM25(unittest.TestCase): + def test_max_match_with_itself(self): + """ Document should show maximum matching with itself """ + weights = get_bm25_weights(common_texts) + for index, doc_weights in enumerate(weights): + expected = max(doc_weights) + predicted = doc_weights[index] + self.assertAlmostEqual(expected, predicted) + + def test_nonnegative_weights(self): + """ All the weights for a partiular document should be non negative """ + weights = get_bm25_weights(common_texts) + for doc_weights in weights: + for weight in doc_weights: + self.assertTrue(weight >= 0.) + + def test_same_match_with_same_document(self): + """ A document should always get the same weight when matched with a particular document """ + corpus = [['cat', 'dog', 'mouse'], ['cat', 'lion'], ['cat', 'lion']] + weights = get_bm25_weights(corpus) + self.assertAlmostEqual(weights[0][1], weights[0][2]) + + def test_disjoint_docs_if_weight_zero(self): + """ Two disjoint documents should have zero matching""" + corpus = [['cat', 'dog', 'lion'], ['goat', 'fish', 'tiger']] + weights = get_bm25_weights(corpus) + self.assertAlmostEqual(weights[0][1], 0) + self.assertAlmostEqual(weights[1][0], 0) + + +if __name__ == '__main__': + logging.basicConfig(level=logging.DEBUG) + unittest.main()