Skip to content

Commit

Permalink
feat(search): implement BM25 similarity algorithm and refactor Simila…
Browse files Browse the repository at this point in the history
…rChunkSearcher

The commit introduces the BM25 similarity algorithm for text matching, adds comprehensive tests for BM25Similarity, and refactors the SimilarChunkSearcher to use the new computeInputSimilarity method. The JaccardSimilarity class is updated to align with the new Similarity interface, and the StopwordsBasedTokenizer is now a shared utility.
  • Loading branch information
phodal committed Jul 17, 2024
1 parent 27c65c1 commit ef32475
Show file tree
Hide file tree
Showing 6 changed files with 160 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class SimilarChunksSearch(private var snippetLength: Int = 60, private var maxRe
val mostRecentFilesRelativePaths = mostRecentFiles.mapNotNull { relativePathTo(it, element) }

val chunks = extractChunks(element, mostRecentFiles)
val jaccardSimilarities = tokenLevelJaccardSimilarity(element.text, chunks)
val jaccardSimilarities = computeInputSimilarity(element.text, chunks)

val similarChunks: List<Pair<String, String>> =
jaccardSimilarities.mapIndexedNotNull { fileIndex, jaccardList ->
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package com.phodal.shirecore.search.algorithm

import kotlin.math.log10

class BM25Similarity : Similarity {
private val k1 = 1.5
private val b = 0.75

override fun computeInputSimilarity(query: String, chunks: List<List<String>>): List<List<Double>> {
val docCount = chunks.size
val avgDocLength = chunks.map { it.size }.average()
val idfMap = computeIDF(chunks, docCount)

// Tokenize the query
val queryTerms = tokenize(query)

return chunks.map { doc ->
val docLength = doc.size
queryTerms.map { term ->
val tf = doc.count { it == term }.toDouble()
val idf = idfMap[term] ?: 0.0
val numerator = tf * (k1 + 1)
val denominator = tf + k1 * (1 - b + b * (docLength / avgDocLength))
idf * (numerator / denominator)
}
}
}

fun computeIDF(chunks: List<List<String>>, docCount: Int): Map<String, Double> {
val termDocCount = mutableMapOf<String, Int>()

chunks.forEach { doc ->
doc.toSet().forEach { term ->
termDocCount[term] = termDocCount.getOrDefault(term, 0) + 1
}
}

return termDocCount.mapValues { (_, count) ->
log10((docCount - count + 0.5) / (count + 0.5) + 1.0)
}
}
}

Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
package com.phodal.shirecore.search.algorithm

import com.phodal.shirecore.search.tokenizer.StopwordsBasedTokenizer

open class JaccardSimilarity : Similarity {
/**
* The `tokenLevelJaccardSimilarity` method calculates the Jaccard similarity between a query string and an array of string
Expand All @@ -12,7 +10,7 @@ open class JaccardSimilarity : Similarity {
* @param chunks An array of string arrays (chunks) to compare against the query.
* @return A two-dimensional array representing the Jaccard similarity scores between the query and each chunk.
*/
fun tokenLevelJaccardSimilarity(query: String, chunks: List<List<String>>): List<List<Double>> {
override fun computeInputSimilarity(query: String, chunks: List<List<String>>): List<List<Double>> {
val currentFileTokens = tokenize(query)
return chunks.map { list ->
list.map { it ->
Expand All @@ -22,10 +20,6 @@ open class JaccardSimilarity : Similarity {
}
}

fun tokenize(input: String): Set<String> {
return StopwordsBasedTokenizer.instance().tokenize(input)
}

fun similarityScore(set1: Set<String>, set2: Set<String>): Double {
val intersectionSize = set1.intersect(set2).size
val unionSize = set1.union(set2).size
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
package com.phodal.shirecore.search.algorithm

import com.phodal.shirecore.search.tokenizer.StopwordsBasedTokenizer

interface Similarity {
fun tokenize(input: String): Set<String> {
return StopwordsBasedTokenizer.instance().tokenize(input)
}

fun computeInputSimilarity(query: String, chunks: List<List<String>>): List<List<Double>>
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
package com.phodal.shirecore.search.algorithm

import junit.framework.TestCase.*
import org.junit.Test
import kotlin.math.log10

/**
* Unit tests for BM25Similarity class.
*/
class BM25SimilarityTest {
@Test
fun testComputeInputSimilarity() {
val bm25 = BM25Similarity()

val query = "sample query"
val chunks = listOf(
listOf("this", "is", "a", "sample", "document"),
listOf("this", "document", "is", "another", "sample"),
listOf("one", "more", "sample", "document")
)

val similarity = bm25.computeInputSimilarity(query, chunks)

assertNotNull(similarity)
assertEquals(3, similarity.size) // We have 3 documents
similarity.forEach { docSim ->
assertEquals(2, docSim.size) // Our query has 2 terms
}

// Print similarity for manual inspection (not a part of actual tests)
similarity.forEachIndexed { index, docSim ->
println("Document $index: $docSim")
}
}

@Test
fun testComputeIDF() {
val bm25 = BM25Similarity()

val chunks = listOf(
listOf("this", "is", "a", "sample", "document"),
listOf("this", "document", "is", "another", "sample"),
listOf("one", "more", "sample", "document")
)
val docCount = chunks.size

val idfMap = bm25.computeIDF(chunks, docCount)

assertNotNull(idfMap)
assertTrue(idfMap.isNotEmpty())
assertEquals(8, idfMap.size) // There are 8 unique terms

// Validate some IDF values manually
val expectedIDFThis = log10((docCount - 2 + 0.5) / (2 + 0.5) + 1.0)
val expectedIDFSample = log10((docCount - 3 + 0.5) / (3 + 0.5) + 1.0)
val expectedIDFOne = log10((docCount - 1 + 0.5) / (1 + 0.5) + 1.0)

assertEquals(expectedIDFThis, idfMap["this"])
assertEquals(expectedIDFSample, idfMap["sample"])
assertEquals(expectedIDFOne, idfMap["one"])

// Print IDF map for manual inspection (not a part of actual tests)
idfMap.forEach { (term, idf) ->
println("Term: $term, IDF: $idf")
}
}

@Test
fun `should compute similarity for query and documents correctly`() {
val similarity = BM25Similarity()
val chunks = listOf(
listOf("apple", "banana", "apple"),
listOf("banana", "cherry"),
listOf("apple", "cherry")
)
val query = "apple banana"

val result = similarity.computeInputSimilarity(query, chunks)

assertNotNull(result)
assertEquals(3, result.size) // Ensure one result per document
assertEquals(2, result[0].size) // Ensure one score per term in the query

// Check if the computed similarity values are non-negative
val allPositive = result.flatten().all { it >= 0.0 }
assertTrue(allPositive)
}

@Test
fun `should handle query term not present in documents`() {
val similarity = BM25Similarity()
val chunks = listOf(
listOf("apple", "banana", "apple"),
listOf("banana", "cherry"),
listOf("apple", "cherry")
)
val query = "apple orange"

val result = similarity.computeInputSimilarity(query, chunks)

assertNotNull(result)
assertEquals(3, result.size) // Ensure one result per document
assertEquals(2, result[0].size) // Ensure one score per term in the query

// Check if the score for the term 'orange' is 0.0 as it is not present in the documents
assertEquals(0.0, result[0][1])
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class JaccardSimilarityTest {
val chunks = listOf(listOf("test", "query"), listOf("another", "test"), listOf("query", "test"))

// when
val result = jaccardSimilarity.tokenLevelJaccardSimilarity(query, chunks)
val result = jaccardSimilarity.computeInputSimilarity(query, chunks)

// then
assertEquals(3, result.size)
Expand Down

0 comments on commit ef32475

Please sign in to comment.